diff --git a/frontend/src/components/MessageWithCitations.tsx b/frontend/src/components/MessageWithCitations.tsx
index 692ee3e..4b1b409 100644
--- a/frontend/src/components/MessageWithCitations.tsx
+++ b/frontend/src/components/MessageWithCitations.tsx
@@ -130,6 +130,14 @@ function CitationBadge({
/**
* Process text to replace [cite: N] markers with citation badges
* ADR-3.4.1: Uses [cite: N] format to avoid false positives with [N]
+ *
+ * Citation Index Mapping:
+ * - LLM outputs: [cite: 1], [cite: 2], ..., [cite: N] (1-based, user-friendly)
+ * - Backend emits: cite-0, cite-1, ..., cite-(N-1) (0-based, internal)
+ * - Frontend displays: 1, 2, ..., N (1-based, user-friendly)
+ *
+ * Example: For 9 chunks
+ * - LLM: [cite: 9] → citationId: cite-8 → Display: 9
*/
function processTextWithCitations(
text: string,
@@ -149,15 +157,20 @@ function processTextWithCitations(
result.push(text.slice(lastIndex, match.index));
}
- const citationIndex = parseInt(match[1], 10);
- const citationId = `cite-${citationIndex}`;
+ // Parse the 1-based citation number from LLM output (e.g., [cite: 9])
+ const llmCitationNumber = parseInt(match[1], 10);
+
+ // Convert to 0-based citation ID for lookup (e.g., cite-8)
+ // LLM uses 1-based: [cite: 1] through [cite: N]
+ // Backend emits 0-based: cite-0 through cite-(N-1)
+ const citationId = `cite-${llmCitationNumber - 1}`;
const citation = citations[citationId];
if (citation) {
result.push(
@@ -166,7 +179,7 @@ function processTextWithCitations(
// Graceful degradation for invalid citations
result.push(
- [{citationIndex}]
+ [{llmCitationNumber}]
);
}
diff --git a/frontend/src/components/__tests__/CitationDeepDive.test.tsx b/frontend/src/components/__tests__/CitationDeepDive.test.tsx
index b2f39d1..38bbfcc 100644
--- a/frontend/src/components/__tests__/CitationDeepDive.test.tsx
+++ b/frontend/src/components/__tests__/CitationDeepDive.test.tsx
@@ -53,13 +53,13 @@ describe('Citation Deep-Dive Integration', () => {
render(
);
- const citationBadge = screen.getByRole('button', { name: /Citation 0/i });
+ const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
await user.click(citationBadge);
expect(handleCitationClick).toHaveBeenCalledWith(mockCitations['cite-0']);
@@ -77,17 +77,18 @@ describe('Citation Deep-Dive Integration', () => {
render(
);
- const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
+ const citationBadge = screen.getByRole('button', { name: /Citation 2/i });
await user.click(citationBadge);
expect(receivedCitation).not.toBeNull();
expect(receivedCitation!.documentId).toBe('doc-uuid-456');
+ expect(receivedCitation!.id).toBe('cite-1');
});
});
@@ -110,13 +111,13 @@ describe('Citation Deep-Dive Integration', () => {
render(
);
- const citationBadge = screen.getByRole('button', { name: /Citation 0/i });
+ const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
await user.click(citationBadge);
// The handler is called - workspace page validates and shows error
diff --git a/frontend/src/components/__tests__/MessageWithCitations.test.tsx b/frontend/src/components/__tests__/MessageWithCitations.test.tsx
index aa77542..dc5c147 100644
--- a/frontend/src/components/__tests__/MessageWithCitations.test.tsx
+++ b/frontend/src/components/__tests__/MessageWithCitations.test.tsx
@@ -11,9 +11,11 @@ import { MessageWithCitations } from '../MessageWithCitations';
import type { CitationMap } from '@/types/citation';
describe('MessageWithCitations', () => {
+ // Mock data uses 0-based citation IDs to match backend format
+ // LLM outputs [cite: 1], [cite: 2] → Maps to cite-0, cite-1
const mockCitations: CitationMap = {
- 'cite-1': {
- id: 'cite-1',
+ 'cite-0': {
+ id: 'cite-0',
title: 'intro.pdf',
mediaType: 'text/plain',
chunkIndex: 0,
@@ -21,8 +23,8 @@ describe('MessageWithCitations', () => {
preview: 'Python is a powerful programming language used worldwide...',
documentId: 'doc-uuid-1',
},
- 'cite-2': {
- id: 'cite-2',
+ 'cite-1': {
+ id: 'cite-1',
title: 'advanced.pdf',
mediaType: 'text/plain',
chunkIndex: 5,
@@ -152,7 +154,7 @@ describe('MessageWithCitations', () => {
const button = screen.getByRole('button');
await user.click(button);
- expect(handleClick).toHaveBeenCalledWith(mockCitations['cite-1']);
+ expect(handleClick).toHaveBeenCalledWith(mockCitations['cite-0']);
});
});
diff --git a/ragitect/agents/rag/streaming.py b/ragitect/agents/rag/streaming.py
index 14d5811..2bda85a 100644
--- a/ragitect/agents/rag/streaming.py
+++ b/ragitect/agents/rag/streaming.py
@@ -68,7 +68,7 @@ class LangGraphToAISDKAdapter:
[cite: N] markers split across token chunks (20-char lookahead).
Reference Implementation:
- - See git history for CitationStreamParser (pre-Story 4.3)
+ - See git history for CitationStreamParser
- Production examples: docs/research/2025-12-28-ENG-4.3-langgraph-streaming-adapter.md
Example:
@@ -98,11 +98,38 @@ def _build_citations_from_context(
) -> list[Citation]:
"""Build Citation objects from context chunks.
+ Node-Level Streaming Simplification:
+ ================================================
+ This method emits ALL context chunks as citations upfront, not just
+ the ones referenced in the LLM response. This is intentional because:
+
+ 1. **No Buffer Needed**: Node-level streaming (stream_mode="updates")
+ provides complete text per node. Citation markers like [cite: 1]
+ are never split across chunks, unlike token-level streaming.
+
+ 2. **Indexing Scheme**:
+ - Internal (Python): 0-based (enumerate index)
+ - LLM Prompt: 1-based ([Chunk 1], [Chunk 2] in context)
+ - SSE sourceId: 0-based (cite-0, cite-1)
+ - Frontend display: 1-based (Citation 1, 2, etc.)
+
+ 3. **Frontend Compatibility**: AI SDK useChat() receives all source
+ documents before text-delta, enabling immediate citation rendering.
+
+ 4. **Future Enhancement**: Detect only cited sources via regex:
+ ```python
+ pattern = re.compile(r"\\[cite:\\s*(\\d+)\\]")
+ for match in pattern.finditer(full_text):
+ cite_idx = int(match.group(1)) - 1 # 1-based to 0-based
+ ```
+ This would filter to only actually-cited sources.
+
Args:
- context_chunks: List of context chunk dicts from merge_context node
+ context_chunks: List of context chunk dicts from merge_context node.
+ Expected keys: document_id, title, chunk_index, score/rerank_score, content
Returns:
- List of Citation objects ready for streaming
+ List of Citation objects ready for streaming as source-document events.
"""
citations = []
for idx, chunk in enumerate(context_chunks):
diff --git a/ragitect/api/v1/chat.py b/ragitect/api/v1/chat.py
index 0876f66..f664f9e 100644
--- a/ragitect/api/v1/chat.py
+++ b/ragitect/api/v1/chat.py
@@ -11,7 +11,6 @@
import json
import logging
-import re
import uuid
from collections.abc import AsyncGenerator
from uuid import UUID
@@ -25,7 +24,6 @@
from ragitect.agents.rag import build_rag_graph
from ragitect.agents.rag.state import RAGState
from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter
-from ragitect.api.schemas.chat import Citation
from ragitect.prompts.rag_prompts import build_rag_system_prompt
from ragitect.services.config import EmbeddingConfig
from ragitect.services.database.connection import get_async_session
@@ -111,174 +109,6 @@ async def format_sse_stream(
yield f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n"
-def build_citation_metadata(context_chunks: list[dict]) -> list[Citation]:
- """Build citation metadata from context chunks.
-
- NOTE: Prompt engineering for [N] format
- This function just prepares metadata for frontend consumption.
-
- Args:
- context_chunks: Chunks returned from retrieve_context()
-
- Returns:
- List of Citation objects for streaming
- """
- citations = []
- for i, chunk in enumerate(context_chunks):
- citations.append(
- Citation.from_context_chunk(
- index=i + 1, # 1-based index
- document_id=str(chunk.get("document_id", "")),
- document_name=chunk.get("document_name", "Unknown"),
- chunk_index=chunk.get("chunk_index", 0),
- similarity=chunk.get("rerank_score") or chunk.get("similarity", 0.0),
- content=chunk.get("content", ""),
- )
- )
- return citations
-
-
-class CitationStreamParser:
- """Stateful parser for detecting citations across chunk boundaries.
-
- ADR Decision: Real-time regex streaming with cross-chunk buffering.
- Handles edge case where LLM outputs '[cite:' in one chunk and '0]' in next.
-
- ADR-3.4.1: Citation format changed from [N] to [cite: N] to avoid
- false positives with markdown lists and array indices.
- """
-
- def __init__(self, citations: list[Citation]):
- """Initialize parser with available citations.
-
- Args:
- citations: Pre-built citation metadata from context chunks
- """
- self.citations = citations
- self.buffer = "" # Buffer for partial citation markers
- self.emitted_ids: set[str] = set() # Track which citations already emitted
- # ADR-3.4.1: Updated pattern from [N] to [cite: N] format
- self.pattern = re.compile(r"\[cite:\s*(\d+)\]")
-
- def parse_chunk(self, chunk: str) -> tuple[str, list[Citation]]:
- """Parse chunk and detect citation markers.
-
- Args:
- chunk: New text chunk from LLM stream
-
- Returns:
- Tuple of (text_to_emit, new_citations_found)
- """
- # Add chunk to buffer
- self.buffer += chunk
-
- # Find all complete citation markers in buffer
- new_citations = []
- for match in self.pattern.finditer(self.buffer):
- cite_idx = int(match.group(1))
- cite_id = f"cite-{cite_idx}"
-
- # Validate citation index (ADR: Hallucination Handling)
- # 1-based index check
- if cite_idx < 1 or cite_idx > len(self.citations):
- logger.warning(
- "LLM cited non-existent source [%d] (only %d chunks available)",
- cite_idx,
- len(self.citations),
- )
- continue # Graceful degradation - skip invalid citation
-
- # Emit each citation only once
- if cite_id not in self.emitted_ids:
- # Map 1-based index to 0-based list
- new_citations.append(self.citations[cite_idx - 1])
- self.emitted_ids.add(cite_id)
-
- # Emit text, but keep last 20 chars in buffer for partial markers
- # Max citation marker length: "[cite: 9999]" = 12 chars, buffer 20 for safety
- if len(self.buffer) > 20:
- text_to_emit = self.buffer[:-20]
- self.buffer = self.buffer[-20:]
- else:
- text_to_emit = ""
-
- return text_to_emit, new_citations
-
- def flush(self) -> str:
- """Flush remaining buffer at end of stream.
-
- Returns:
- Any remaining text in the buffer
- """
- remaining = self.buffer
- self.buffer = ""
- return remaining
-
-
-async def format_sse_stream_with_citations(
- chunks: AsyncGenerator[str, None],
- citations: list[Citation],
-) -> AsyncGenerator[str, None]:
- """Format LLM chunks with AI SDK UI Message Stream Protocol v1 + citations.
-
- ADR: Real-time regex streaming with cross-chunk buffering.
- Emits citations as 'source-document' parts for AI SDK useChat.
-
- Args:
- chunks: LLM token stream
- citations: Pre-built citation metadata
-
- Yields:
- SSE formatted messages (UI Message Stream Protocol v1):
- - data: {"type": "start", "messageId": "..."} - Message start
- - data: {"type": "text-start", "id": "..."} - Text block start
- - data: {"type": "text-delta", "id": "...", "delta": "..."} - Text chunks
- - data: {"type": "source-document", "sourceId": "...", ...} - Citations
- - data: {"type": "text-end", "id": "..."} - Text block end
- - data: {"type": "finish", "finishReason": "stop"} - Stream end
- """
- message_id = str(uuid.uuid4())
- text_id = str(uuid.uuid4())
- parser = CitationStreamParser(citations)
-
- # Message start (protocol requirement)
- yield f"data: {json.dumps({'type': 'start', 'messageId': message_id})}\n\n"
-
- # Text block start
- yield f"data: {json.dumps({'type': 'text-start', 'id': text_id})}\n\n"
-
- async for chunk in chunks:
- # Parse chunk for citations
- text_to_emit, new_citations = parser.parse_chunk(chunk)
-
- # Emit text delta if we have text
- if text_to_emit:
- yield f"data: {json.dumps({'type': 'text-delta', 'id': text_id, 'delta': text_to_emit})}\n\n"
-
- # Emit source-document parts for detected citations
- for citation in new_citations:
- source_doc = citation.to_sse_dict()
- yield f"data: {json.dumps(source_doc)}\n\n"
-
- # Flush remaining buffer
- remaining = parser.flush()
- if remaining:
- yield f"data: {json.dumps({'type': 'text-delta', 'id': text_id, 'delta': remaining})}\n\n"
-
- # Log citation usage for monitoring (ADR: Zero Citations case)
- if citations and not parser.emitted_ids:
- logger.info(
- "LLM response had no citations despite %d available chunks",
- len(citations),
- )
-
- # Text block end
- yield f"data: {json.dumps({'type': 'text-end', 'id': text_id})}\n\n"
-
- # Finish message
- yield f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n"
-
-
async def empty_workspace_response() -> AsyncGenerator[str, None]:
"""Return SSE stream for empty workspace message using AI SDK protocol.
diff --git a/tests/agents/rag/test_streaming.py b/tests/agents/rag/test_streaming.py
index 32a3e68..9e9c228 100644
--- a/tests/agents/rag/test_streaming.py
+++ b/tests/agents/rag/test_streaming.py
@@ -5,6 +5,7 @@
"""
import json
+import logging
from typing import Any, AsyncIterator
import pytest
@@ -305,3 +306,133 @@ def astream(self, inputs, stream_mode):
)
assert source_doc_idx < text_delta_idx, "Citations must appear before text"
+
+
+class TestCitationValidation:
+ """Test citation construction from context chunks.
+
+ Covers empty input, missing optional fields, 0-based source IDs, AI SDK SSE format, and rerank-score precedence.
+ """
+
+ async def test_build_citations_handles_empty_chunks(self):
+ """Test _build_citations_from_context handles empty list gracefully."""
+ from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter
+
+ adapter = LangGraphToAISDKAdapter()
+
+ citations = adapter._build_citations_from_context([])
+
+ assert citations == []
+
+ async def test_build_citations_handles_missing_fields(self):
+ """Test _build_citations_from_context handles chunks with missing optional fields."""
+ from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter
+
+ adapter = LangGraphToAISDKAdapter()
+
+ # Minimal chunk with only required fields missing
+ chunks = [
+ {
+ "content": "Some content",
+ # Missing: document_id, title, chunk_index, score
+ }
+ ]
+
+ citations = adapter._build_citations_from_context(chunks)
+
+ assert len(citations) == 1
+ assert citations[0].source_id == "cite-0"
+ # Should use defaults for missing fields
+ assert citations[0].title == "Unknown"
+
+ async def test_citation_indices_are_zero_based(self):
+ """Test that citation indices are 0-based (cite-0, cite-1, etc.)."""
+ from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter
+
+ adapter = LangGraphToAISDKAdapter()
+
+ chunks = [
+ {
+ "content": "First",
+ "title": "doc1.pdf",
+ "document_id": "d1",
+ "chunk_index": 0,
+ "score": 0.9,
+ },
+ {
+ "content": "Second",
+ "title": "doc2.pdf",
+ "document_id": "d2",
+ "chunk_index": 1,
+ "score": 0.8,
+ },
+ {
+ "content": "Third",
+ "title": "doc3.pdf",
+ "document_id": "d3",
+ "chunk_index": 2,
+ "score": 0.7,
+ },
+ ]
+
+ citations = adapter._build_citations_from_context(chunks)
+
+ assert len(citations) == 3
+ assert citations[0].source_id == "cite-0"
+ assert citations[1].source_id == "cite-1"
+ assert citations[2].source_id == "cite-2"
+
+ async def test_citation_sse_format_matches_ai_sdk_protocol(self):
+ """Test that Citation.to_sse_dict() produces valid AI SDK source-document format."""
+ from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter
+
+ adapter = LangGraphToAISDKAdapter()
+
+ chunks = [
+ {
+ "content": "Test content",
+ "title": "research.pdf",
+ "document_id": "doc-abc-123",
+ "chunk_index": 5,
+ "score": 0.95,
+ }
+ ]
+
+ citations = adapter._build_citations_from_context(chunks)
+ sse_dict = citations[0].to_sse_dict()
+
+ # Verify AI SDK source-document format
+ assert sse_dict["type"] == "source-document"
+ assert sse_dict["sourceId"] == "cite-0"
+ assert sse_dict["mediaType"] == "text/plain"
+ assert sse_dict["title"] == "research.pdf"
+
+ # Verify providerMetadata.ragitect structure
+ ragitect_meta = sse_dict["providerMetadata"]["ragitect"]
+ assert ragitect_meta["chunkIndex"] == 5
+ assert ragitect_meta["similarity"] == 0.95
+ assert ragitect_meta["preview"] == "Test content"
+ assert ragitect_meta["documentId"] == "doc-abc-123"
+
+ async def test_citation_uses_rerank_score_over_score(self):
+ """Test that rerank_score takes precedence over score when available."""
+ from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter
+
+ adapter = LangGraphToAISDKAdapter()
+
+ chunks = [
+ {
+ "content": "Test",
+ "title": "doc.pdf",
+ "document_id": "d1",
+ "chunk_index": 0,
+ "score": 0.5, # Original vector similarity
+ "rerank_score": 0.95, # Reranker score should take precedence
+ }
+ ]
+
+ citations = adapter._build_citations_from_context(chunks)
+ sse_dict = citations[0].to_sse_dict()
+
+ # Should use rerank_score (0.95) not score (0.5)
+ assert sse_dict["providerMetadata"]["ragitect"]["similarity"] == 0.95
diff --git a/tests/api/v1/test_chat_streaming.py b/tests/api/v1/test_chat_streaming.py
index beb283d..0849824 100644
--- a/tests/api/v1/test_chat_streaming.py
+++ b/tests/api/v1/test_chat_streaming.py
@@ -23,7 +23,7 @@
def setup_langgraph_streaming_mocks(mocker):
"""Setup common mocks for LangGraph streaming architecture.
- After Story 4.3, the endpoint uses LangGraphToAISDKAdapter with full graph execution.
+ The endpoint uses LangGraphToAISDKAdapter with full graph execution.
This helper mocks all required dependencies for the streaming pipeline.
Args:
@@ -72,7 +72,7 @@ async def mock_embed_fn(model, text: str):
mock_structured_llm = mocker.AsyncMock()
# Mock strategy response for generate_strategy node
- from ragitect.agents.rag.schemas import SearchStrategy, Search
+ from ragitect.agents.rag.schemas import Search, SearchStrategy
mock_strategy = SearchStrategy(
reasoning="Test analysis",
@@ -141,7 +141,7 @@ async def test_stream_returns_sse_content_type(self, async_client, mocker):
return_value=mock_doc_repo,
)
- # Setup LangGraph streaming mocks (Story 4.3)
+ # Setup LangGraph streaming mocks
setup_langgraph_streaming_mocks(mocker)
response = await async_client.post(
@@ -182,7 +182,7 @@ async def test_stream_format_is_sse(self, async_client, mocker):
return_value=mock_doc_repo,
)
- # Setup LangGraph streaming mocks (Story 4.3)
+ # Setup LangGraph streaming mocks
setup_langgraph_streaming_mocks(mocker)
response = await async_client.post(
@@ -405,7 +405,7 @@ async def test_chat_retrieves_relevant_chunks(self, async_client, mocker):
return_value=mock_doc_repo,
)
- # Setup LangGraph streaming mocks (Story 4.3)
+ # Setup LangGraph streaming mocks
setup_langgraph_streaming_mocks(mocker)
response = await async_client.post(
@@ -443,7 +443,7 @@ async def test_chat_uses_context_in_prompt(self, async_client, mocker):
return_value=mock_doc_repo,
)
- # Setup LangGraph streaming mocks (Story 4.3)
+ # Setup LangGraph streaming mocks
setup_langgraph_streaming_mocks(mocker)
response = await async_client.post(
@@ -502,7 +502,7 @@ async def test_chat_with_provider_override_uses_specified_provider(
return_value=mock_doc_repo,
)
- # Setup LangGraph streaming mocks (Story 4.3)
+ # Setup LangGraph streaming mocks
mocks = setup_langgraph_streaming_mocks(mocker)
# Track which provider was requested
@@ -643,7 +643,7 @@ async def test_chat_without_provider_uses_default(self, async_client, mocker):
return_value=mock_doc_repo,
)
- # Setup LangGraph streaming mocks (Story 4.3)
+ # Setup LangGraph streaming mocks
mocks = setup_langgraph_streaming_mocks(mocker)
mock_create_llm = mocker.AsyncMock()
@@ -665,277 +665,6 @@ async def test_chat_without_provider_uses_default(self, async_client, mocker):
assert call_kwargs.get("provider") is None
-class TestCitationStreaming:
- """Tests for citation detection and streaming."""
-
- async def test_citation_parser_uses_cite_format(self):
- """Test that CitationStreamParser uses [cite: N] format"""
- from ragitect.api.schemas.chat import Citation
- from ragitect.api.v1.chat import CitationStreamParser
-
- citations = [
- Citation.from_context_chunk(1, "doc-id-1", "doc1.pdf", 0, 0.9, "Content 1"),
- Citation.from_context_chunk(2, "doc-id-2", "doc2.pdf", 1, 0.8, "Content 2"),
- ]
-
- parser = CitationStreamParser(citations)
-
- # Parse chunks with [cite: N] format
- text1, found1 = parser.parse_chunk("Python is great[cite: 1] and ")
- text2, found2 = parser.parse_chunk("versatile[cite: 2].")
- remaining = parser.flush()
-
- # Should find both citations with new format
- assert len(found1) == 1
- assert found1[0].source_id == "cite-1"
- assert len(found2) == 1
- assert found2[0].source_id == "cite-2"
-
- async def test_citation_parser_ignores_old_bare_bracket_format(self):
- """Test that parser does NOT match old [N] format"""
- from ragitect.api.schemas.chat import Citation
- from ragitect.api.v1.chat import CitationStreamParser
-
- citations = [
- Citation.from_context_chunk(1, "doc-id-1", "doc1.pdf", 0, 0.9, "Content 1"),
- ]
-
- parser = CitationStreamParser(citations)
-
- # Parse chunks with OLD bare [N] format - should NOT match
- text1, found1 = parser.parse_chunk("Python is great[1] but this won't match.")
- remaining = parser.flush()
-
- # Should NOT find the citation with old format
- assert len(found1) == 0
-
- async def test_build_citation_metadata_creates_citations_from_chunks(self):
- """Test that build_citation_metadata creates Citation objects from context chunks (AC1, AC2)."""
- from ragitect.api.v1.chat import build_citation_metadata
-
- context_chunks = [
- {
- "content": "Python is a programming language used for many applications.",
- "document_name": "python-intro.pdf",
- "chunk_index": 0,
- "similarity": 0.95,
- },
- {
- "content": "FastAPI is a modern web framework for building APIs.",
- "document_name": "fastapi-docs.pdf",
- "chunk_index": 3,
- "rerank_score": 0.88, # Should use rerank_score over similarity
- },
- ]
-
- citations = build_citation_metadata(context_chunks)
-
- assert len(citations) == 2
- assert citations[0].source_id == "cite-1"
- assert citations[0].title == "python-intro.pdf"
- assert citations[1].source_id == "cite-2"
- assert citations[1].title == "fastapi-docs.pdf"
-
- async def test_citation_stream_parser_detects_markers(self):
- """Test that CitationStreamParser detects [N] markers in text (AC1)."""
- from ragitect.api.schemas.chat import Citation
- from ragitect.api.v1.chat import CitationStreamParser
-
- citations = [
- Citation.from_context_chunk(1, "doc-id-1", "doc1.pdf", 0, 0.9, "Content 1"),
- Citation.from_context_chunk(2, "doc-id-2", "doc2.pdf", 1, 0.8, "Content 2"),
- ]
-
- parser = CitationStreamParser(citations)
-
- # Parse chunks with citation markers using [cite: N] format
- text1, found1 = parser.parse_chunk("Python is great[cite: 1] and ")
- text2, found2 = parser.parse_chunk("versatile[cite: 2].")
- remaining = parser.flush()
-
- # Should find both citations
- assert len(found1) == 1
- assert found1[0].source_id == "cite-1"
- assert len(found2) == 1
- assert found2[0].source_id == "cite-2"
-
- async def test_citation_stream_parser_handles_split_markers(self):
- """Test that parser handles citation markers split across chunks (AC1)."""
- from ragitect.api.schemas.chat import Citation
- from ragitect.api.v1.chat import CitationStreamParser
-
- citations = [
- Citation.from_context_chunk(1, "doc-id", "doc.pdf", 0, 0.9, "Content"),
- ]
-
- parser = CitationStreamParser(citations)
-
- # Split "[cite: 1]" across chunks
- text1, found1 = parser.parse_chunk("Hello [cite:")
- text2, found2 = parser.parse_chunk(" 1] world")
- remaining = parser.flush()
-
- # Should eventually find the citation
- all_citations = found1 + found2
- assert len(all_citations) == 1
- assert all_citations[0].source_id == "cite-1"
-
- async def test_citation_stream_parser_ignores_invalid_citations(self, caplog):
- """Test that parser logs warning for invalid citation indices (AC6 - hallucination handling)."""
- import logging
-
- from ragitect.api.schemas.chat import Citation
- from ragitect.api.v1.chat import CitationStreamParser
-
- citations = [
- Citation.from_context_chunk(0, "doc-id", "doc.pdf", 0, 0.9, "Content"),
- ]
-
- parser = CitationStreamParser(citations)
-
- with caplog.at_level(logging.WARNING, logger="ragitect.api.v1.chat"):
- # Try to cite [cite: 99] which doesn't exist
- text, found = parser.parse_chunk("Test [cite: 99] content")
- remaining = parser.flush()
-
- # Should not find any citations (invalid index)
- assert len(found) == 0
- # Should have logged a warning
- assert "cited non-existent source" in caplog.text or "99" in caplog.text
-
- async def test_citation_stream_emits_each_citation_once(self):
- """Test that each citation is only emitted once even if marker appears multiple times."""
- from ragitect.api.schemas.chat import Citation
- from ragitect.api.v1.chat import CitationStreamParser
-
- citations = [
- Citation.from_context_chunk(0, "doc-id", "doc.pdf", 0, 0.9, "Content"),
- ]
-
- parser = CitationStreamParser(citations)
-
- # Same citation marker appears twice (using [cite: N] format)
- assert len(citations) >= 1
- # Test with [cite: 1] which maps to index 0
- text1, found1 = parser.parse_chunk("First[cite: 1] and ")
- text2, found2 = parser.parse_chunk("again[cite: 1].")
- remaining = parser.flush()
-
- # Should only emit once
- total_found = len(found1) + len(found2)
- assert total_found == 1
-
- async def test_format_sse_stream_with_citations_emits_source_documents(self):
- """Test that format_sse_stream_with_citations emits source-document events (AC1, AC2)."""
- from ragitect.api.schemas.chat import Citation
- from ragitect.api.v1.chat import format_sse_stream_with_citations
-
- citations = [
- Citation.from_context_chunk(
- 1, "doc-uuid", "intro.pdf", 0, 0.95, "Python is..."
- ),
- ]
-
- async def mock_chunks():
- yield "Python"
- yield " is great"
- yield "[cite: 1]"
- yield "."
-
- events = []
- async for event in format_sse_stream_with_citations(mock_chunks(), citations):
- events.append(event)
-
- # Should contain source-document event
- source_doc_events = [e for e in events if "source-document" in e]
- assert len(source_doc_events) >= 1
-
- # Verify source-document contains expected fields
- source_event = source_doc_events[0]
- # Expect cite-1 because Citation.from_context_chunk will now create cite-1 if we mock correctly or if we pass explicit ID
- # Wait, from_context_chunk uses index arg.
- # If we pass index=1 manually:
- assert "cite-1" in source_event
- assert "intro.pdf" in source_event
-
- async def test_format_sse_stream_with_citations_handles_zero_citations(
- self, caplog
- ):
- """Test that stream works when LLM doesn't cite any sources (AC6)."""
- import logging
-
- from ragitect.api.schemas.chat import Citation
- from ragitect.api.v1.chat import format_sse_stream_with_citations
-
- # Citations available but LLM doesn't use them
- citations = [
- Citation.from_context_chunk(1, "doc-id", "doc.pdf", 0, 0.9, "Content"),
- ]
-
- async def mock_chunks():
- yield "2 plus 2 equals 4." # No citations
-
- with caplog.at_level(logging.INFO, logger="ragitect.api.v1.chat"):
- events = []
- async for event in format_sse_stream_with_citations(
- mock_chunks(), citations
- ):
- events.append(event)
-
- # Should NOT contain source-document events
- source_doc_events = [e for e in events if "source-document" in e]
- assert len(source_doc_events) == 0
-
- # Should have text-delta events
- text_events = [e for e in events if "text-delta" in e]
- assert len(text_events) >= 1
-
- # Should log that no citations were used
- assert "no citations" in caplog.text.lower()
-
- async def test_chat_endpoint_emits_citation_metadata(self, async_client, mocker):
- """Test that chat endpoint includes source-document events in stream (AC1, AC2)."""
- workspace_id = uuid.uuid4()
- now = datetime.now(timezone.utc)
-
- from ragitect.services.database.models import Workspace
-
- mock_workspace = Workspace(id=workspace_id, name="Test")
- mock_workspace.created_at = now
- mock_workspace.updated_at = now
-
- mock_ws_repo = mocker.AsyncMock()
- mock_ws_repo.get_by_id.return_value = mock_workspace
-
- mocker.patch(
- "ragitect.api.v1.chat.WorkspaceRepository",
- return_value=mock_ws_repo,
- )
-
- mock_doc_repo = mocker.AsyncMock()
- mock_doc_repo.get_by_workspace_count.return_value = 5
-
- mocker.patch(
- "ragitect.api.v1.chat.DocumentRepository",
- return_value=mock_doc_repo,
- )
-
- # Setup LangGraph streaming mocks (Story 4.3)
- setup_langgraph_streaming_mocks(mocker)
-
- response = await async_client.post(
- f"/api/v1/workspaces/{workspace_id}/chat/stream",
- json={"message": "What is Python?"},
- )
-
- assert response.status_code == 200
- content = response.text
-
- # With LangGraph streaming, citations are handled by the adapter
- # Just verify we got a valid SSE stream
- assert "data:" in content
-
-
class TestLangGraphStreaming:
"""Tests for LangGraph streaming adapter integration in chat endpoint."""
@@ -980,7 +709,7 @@ async def test_langgraph_adapter_integration(self, async_client, mocker):
return_value=mock_doc_repo,
)
- # Setup LangGraph streaming mocks (Story 4.3)
+ # Setup LangGraph streaming mocks
setup_langgraph_streaming_mocks(mocker)
response = await async_client.post(
@@ -1032,7 +761,7 @@ async def test_sse_event_format_compliance(self, async_client, mocker):
return_value=mock_doc_repo,
)
- # Setup LangGraph streaming mocks (Story 4.3)
+ # Setup LangGraph streaming mocks
setup_langgraph_streaming_mocks(mocker)
response = await async_client.post(