Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions frontend/src/components/MessageWithCitations.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ function CitationBadge({
/**
* Process text to replace [cite: N] markers with citation badges
* ADR-3.4.1: Uses [cite: N] format to avoid false positives with [N]
*
* Citation Index Mapping:
* - LLM outputs: [cite: 1], [cite: 2], ..., [cite: N] (1-based, user-friendly)
* - Backend emits: cite-0, cite-1, ..., cite-(N-1) (0-based, internal)
* - Frontend displays: 1, 2, ..., N (1-based, user-friendly)
*
* Example: For 9 chunks
* - LLM: [cite: 9] → citationId: cite-8 → Display: 9
*/
function processTextWithCitations(
text: string,
Expand All @@ -149,15 +157,20 @@ function processTextWithCitations(
result.push(text.slice(lastIndex, match.index));
}

const citationIndex = parseInt(match[1], 10);
const citationId = `cite-${citationIndex}`;
// Parse the 1-based citation number from LLM output (e.g., [cite: 9])
const llmCitationNumber = parseInt(match[1], 10);

// Convert to 0-based citation ID for lookup (e.g., cite-8)
// LLM uses 1-based: [cite: 1] through [cite: N]
// Backend emits 0-based: cite-0 through cite-(N-1)
const citationId = `cite-${llmCitationNumber - 1}`;
const citation = citations[citationId];

if (citation) {
result.push(
<CitationBadge
key={`cite-${keyIndex++}`}
citationIndex={citationIndex}
citationIndex={llmCitationNumber} // Display 1-based number to user
citation={citation}
onCitationClick={onCitationClick}
/>
Expand All @@ -166,7 +179,7 @@ function processTextWithCitations(
// Graceful degradation for invalid citations
result.push(
<span key={`invalid-${keyIndex++}`} className="text-muted-foreground">
[{citationIndex}]
[{llmCitationNumber}]
</span>
);
}
Expand Down
13 changes: 7 additions & 6 deletions frontend/src/components/__tests__/CitationDeepDive.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ describe('Citation Deep-Dive Integration', () => {

render(
<MessageWithCitations
content="The architecture is well documented [cite: 0]."
content="The architecture is well documented [cite: 1]."
citations={mockCitations}
onCitationClick={handleCitationClick}
/>
);

const citationBadge = screen.getByRole('button', { name: /Citation 0/i });
const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
await user.click(citationBadge);

expect(handleCitationClick).toHaveBeenCalledWith(mockCitations['cite-0']);
Expand All @@ -77,17 +77,18 @@ describe('Citation Deep-Dive Integration', () => {

render(
<MessageWithCitations
content="Check the deployment guide [cite: 1]."
content="Check the deployment guide [cite: 2]."
citations={mockCitations}
onCitationClick={handleCitationClick}
/>
);

const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
const citationBadge = screen.getByRole('button', { name: /Citation 2/i });
await user.click(citationBadge);

expect(receivedCitation).not.toBeNull();
expect(receivedCitation!.documentId).toBe('doc-uuid-456');
expect(receivedCitation!.id).toBe('cite-1');
});
});

Expand All @@ -110,13 +111,13 @@ describe('Citation Deep-Dive Integration', () => {

render(
<MessageWithCitations
content="Check this [cite: 0]."
content="Check this [cite: 1]."
citations={citationWithoutDocId}
onCitationClick={handleCitationClick}
/>
);

const citationBadge = screen.getByRole('button', { name: /Citation 0/i });
const citationBadge = screen.getByRole('button', { name: /Citation 1/i });
await user.click(citationBadge);

// The handler is called - workspace page validates and shows error
Expand Down
12 changes: 7 additions & 5 deletions frontend/src/components/__tests__/MessageWithCitations.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,20 @@ import { MessageWithCitations } from '../MessageWithCitations';
import type { CitationMap } from '@/types/citation';

describe('MessageWithCitations', () => {
// Mock data uses 0-based citation IDs to match backend format
// LLM outputs [cite: 1], [cite: 2] → Maps to cite-0, cite-1
const mockCitations: CitationMap = {
'cite-1': {
id: 'cite-1',
'cite-0': {
id: 'cite-0',
title: 'intro.pdf',
mediaType: 'text/plain',
chunkIndex: 0,
similarity: 0.95,
preview: 'Python is a powerful programming language used worldwide...',
documentId: 'doc-uuid-1',
},
'cite-2': {
id: 'cite-2',
'cite-1': {
id: 'cite-1',
title: 'advanced.pdf',
mediaType: 'text/plain',
chunkIndex: 5,
Expand Down Expand Up @@ -152,7 +154,7 @@ describe('MessageWithCitations', () => {
const button = screen.getByRole('button');
await user.click(button);

expect(handleClick).toHaveBeenCalledWith(mockCitations['cite-1']);
expect(handleClick).toHaveBeenCalledWith(mockCitations['cite-0']);
});
});

Expand Down
33 changes: 30 additions & 3 deletions ragitect/agents/rag/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class LangGraphToAISDKAdapter:
[cite: N] markers split across token chunks (20-char lookahead).

Reference Implementation:
- See git history for CitationStreamParser (pre-Story 4.3)
- See git history for CitationStreamParser
- Production examples: docs/research/2025-12-28-ENG-4.3-langgraph-streaming-adapter.md

Example:
Expand Down Expand Up @@ -98,11 +98,38 @@ def _build_citations_from_context(
) -> list[Citation]:
"""Build Citation objects from context chunks.

Node-Level Streaming Simplification:
================================================
This method emits ALL context chunks as citations upfront, not just
the ones referenced in the LLM response. This is intentional because:

1. **No Buffer Needed**: Node-level streaming (stream_mode="updates")
provides complete text per node. Citation markers like [cite: 1]
are never split across chunks, unlike token-level streaming.

2. **Indexing Scheme**:
- Internal (Python): 0-based (enumerate index)
- LLM Prompt: 1-based ([Chunk 1], [Chunk 2] in context)
- SSE sourceId: 0-based (cite-0, cite-1)
- Frontend display: 1-based (Citation 1, 2, etc.)

3. **Frontend Compatibility**: AI SDK useChat() receives all source
documents before text-delta, enabling immediate citation rendering.

4. **Future Enhancement**: Detect only cited sources via regex:
```python
pattern = re.compile(r"\\[cite:\\s*(\\d+)\\]")
for match in pattern.finditer(full_text):
cite_idx = int(match.group(1)) - 1 # 1-based to 0-based
```
This would filter to only actually-cited sources.

Args:
context_chunks: List of context chunk dicts from merge_context node
context_chunks: List of context chunk dicts from merge_context node.
Expected keys: document_id, title, chunk_index, score/rerank_score, content

Returns:
List of Citation objects ready for streaming
List of Citation objects ready for streaming as source-document events.
"""
citations = []
for idx, chunk in enumerate(context_chunks):
Expand Down
170 changes: 0 additions & 170 deletions ragitect/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import json
import logging
import re
import uuid
from collections.abc import AsyncGenerator
from uuid import UUID
Expand All @@ -25,7 +24,6 @@
from ragitect.agents.rag import build_rag_graph
from ragitect.agents.rag.state import RAGState
from ragitect.agents.rag.streaming import LangGraphToAISDKAdapter
from ragitect.api.schemas.chat import Citation
from ragitect.prompts.rag_prompts import build_rag_system_prompt
from ragitect.services.config import EmbeddingConfig
from ragitect.services.database.connection import get_async_session
Expand Down Expand Up @@ -111,174 +109,6 @@ async def format_sse_stream(
yield f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n"


def build_citation_metadata(context_chunks: list[dict]) -> list[Citation]:
    """Build citation metadata from context chunks.

    Prepares per-chunk metadata for frontend consumption only; prompt
    engineering for the citation marker format happens elsewhere.

    Args:
        context_chunks: Chunks returned from retrieve_context()

    Returns:
        List of Citation objects for streaming
    """
    # Indices are 1-based so they line up with the marker numbers the
    # LLM is prompted to emit.
    return [
        Citation.from_context_chunk(
            index=position,
            document_id=str(chunk.get("document_id", "")),
            document_name=chunk.get("document_name", "Unknown"),
            chunk_index=chunk.get("chunk_index", 0),
            similarity=chunk.get("rerank_score") or chunk.get("similarity", 0.0),
            content=chunk.get("content", ""),
        )
        for position, chunk in enumerate(context_chunks, start=1)
    ]


class CitationStreamParser:
    """Stateful parser for detecting citations across chunk boundaries.

    ADR Decision: Real-time regex streaming with cross-chunk buffering.
    Handles the edge case where the LLM emits '[cite:' in one chunk and
    the closing portion in the next by holding back a short text tail.

    ADR-3.4.1: Citation format changed from [N] to [cite: N] to avoid
    false positives with markdown lists and array indices.
    """

    def __init__(self, citations: list[Citation]):
        """Initialize parser with available citations.

        Args:
            citations: Pre-built citation metadata from context chunks
        """
        self.citations = citations
        self.buffer = ""  # Trailing text that may still contain a partial marker
        self.emitted_ids: set[str] = set()  # Citation IDs already surfaced
        # ADR-3.4.1: Updated pattern from [N] to [cite: N] format
        self.pattern = re.compile(r"\[cite:\s*(\d+)\]")

    def parse_chunk(self, chunk: str) -> tuple[str, list[Citation]]:
        """Parse chunk and detect citation markers.

        Args:
            chunk: New text chunk from LLM stream

        Returns:
            Tuple of (text_to_emit, new_citations_found)
        """
        self.buffer += chunk

        found: list[Citation] = []
        for hit in self.pattern.finditer(self.buffer):
            number = int(hit.group(1))
            marker_id = f"cite-{number}"

            # Hallucination handling (ADR): markers are 1-based, so any
            # number outside [1, len(citations)] refers to a non-existent
            # source and is dropped with a warning.
            if not 1 <= number <= len(self.citations):
                logger.warning(
                    "LLM cited non-existent source [%d] (only %d chunks available)",
                    number,
                    len(self.citations),
                )
                continue

            # Surface each citation at most once per stream.
            if marker_id not in self.emitted_ids:
                self.emitted_ids.add(marker_id)
                found.append(self.citations[number - 1])

        # Hold back the last 20 characters in case they end mid-marker.
        # Longest realistic marker "[cite: 9999]" is 12 chars; 20 adds margin.
        if len(self.buffer) <= 20:
            return "", found
        ready, self.buffer = self.buffer[:-20], self.buffer[-20:]
        return ready, found

    def flush(self) -> str:
        """Flush remaining buffer at end of stream.

        Returns:
            Any remaining text in the buffer
        """
        tail, self.buffer = self.buffer, ""
        return tail


async def format_sse_stream_with_citations(
    chunks: AsyncGenerator[str, None],
    citations: list[Citation],
) -> AsyncGenerator[str, None]:
    """Format LLM chunks with AI SDK UI Message Stream Protocol v1 + citations.

    ADR: Real-time regex streaming with cross-chunk buffering.
    Emits citations as 'source-document' parts for AI SDK useChat.

    Args:
        chunks: LLM token stream
        citations: Pre-built citation metadata

    Yields:
        SSE formatted messages (UI Message Stream Protocol v1):
        - data: {"type": "start", "messageId": "..."} - Message start
        - data: {"type": "text-start", "id": "..."} - Text block start
        - data: {"type": "text-delta", "id": "...", "delta": "..."} - Text chunks
        - data: {"type": "source-document", "sourceId": "...", ...} - Citations
        - data: {"type": "text-end", "id": "..."} - Text block end
        - data: {"type": "finish", "finishReason": "stop"} - Stream end
    """

    def sse(payload: dict) -> str:
        # Every protocol event is one JSON object on a 'data:' line.
        return f"data: {json.dumps(payload)}\n\n"

    message_id = str(uuid.uuid4())
    text_id = str(uuid.uuid4())
    parser = CitationStreamParser(citations)

    # Protocol preamble: message start, then open the text block.
    yield sse({"type": "start", "messageId": message_id})
    yield sse({"type": "text-start", "id": text_id})

    async for piece in chunks:
        ready_text, fresh_citations = parser.parse_chunk(piece)

        if ready_text:
            yield sse({"type": "text-delta", "id": text_id, "delta": ready_text})

        # Citations surface as 'source-document' parts as soon as the
        # parser detects their markers in the token stream.
        for citation in fresh_citations:
            yield sse(citation.to_sse_dict())

    # Drain whatever the parser held back for partial-marker detection.
    leftover = parser.flush()
    if leftover:
        yield sse({"type": "text-delta", "id": text_id, "delta": leftover})

    # Monitoring hook (ADR: Zero Citations case) — context was available
    # but the model never referenced any of it.
    if citations and not parser.emitted_ids:
        logger.info(
            "LLM response had no citations despite %d available chunks",
            len(citations),
        )

    yield sse({"type": "text-end", "id": text_id})
    yield sse({"type": "finish", "finishReason": "stop"})


async def empty_workspace_response() -> AsyncGenerator[str, None]:
"""Return SSE stream for empty workspace message using AI SDK protocol.

Expand Down
Loading
Loading