Commit 58d5547

refactor remove old hardcoded pipeline (#32)
1 parent ed5d876 commit 58d5547

14 files changed

Lines changed: 25 additions & 3895 deletions

.env.example

Lines changed: 0 additions & 5 deletions
@@ -59,8 +59,3 @@ RETRIEVAL_ADAPTIVE_K_MIN=4
 RETRIEVAL_ADAPTIVE_K_MAX=16
 RETRIEVAL_ADAPTIVE_K_GAP_THRESHOLD=0.15
 RETRIEVAL_TOKEN_BUDGET=4000
-
-# Agent Configuration
-# USE_LANGGRAPH_RETRIEVAL: Enable LangGraph-based agent pipeline for retrieval (default: false)
-# Set to 'true' to use the strategy-based search pipeline (Story 4.2)
-USE_LANGGRAPH_RETRIEVAL=false
ragitect/api/v1/chat.py

Lines changed: 15 additions & 234 deletions
@@ -1,11 +1,12 @@
 """Chat streaming endpoint using Server-Sent Events (SSE) with RAG integration.

 This module provides the SSE streaming endpoint for chat functionality with
-full Retrieval-Augmented Generation (RAG) integration.
+full Retrieval-Augmented Generation (RAG) integration using LangGraph agent-based pipeline.

-Supports two retrieval modes:
-1. Legacy: Manual orchestration via retrieve_context()
-2. LangGraph: Agent-based pipeline via retrieve_context_with_graph()
+The RAG pipeline uses intelligent query decomposition with parallel search execution:
+- Strategy generation: LLM decomposes queries into 1-5 targeted search terms
+- Parallel search: Each term searches independently with reranking, MMR, and adaptive-K
+- Context merging: Deduplicate and re-rank aggregated results for final context
 """

 import json
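The new docstring describes a decompose, search-in-parallel, merge flow. A minimal sketch of that shape (illustrative only; the chunk keys "document_id", "chunk_index", and "score" are assumed here, not taken from the ragitect schema):

import asyncio
from typing import Awaitable, Callable

Chunk = dict  # assumed shape: {"document_id", "chunk_index", "score", "content", ...}


async def decompose_search_merge(
    terms: list[str],
    search_fn: Callable[[str], Awaitable[list[Chunk]]],
    final_k: int = 8,
) -> list[Chunk]:
    """Run one search per decomposed term in parallel, then dedupe and re-rank."""
    # Parallel search: each decomposed term retrieves independently.
    per_term = await asyncio.gather(*(search_fn(term) for term in terms))

    # Context merging: deduplicate by (document_id, chunk_index),
    # keeping the best-scoring copy of each chunk.
    best: dict[tuple, Chunk] = {}
    for chunk in (c for results in per_term for c in results):
        key = (chunk["document_id"], chunk["chunk_index"])
        if key not in best or chunk["score"] > best[key]["score"]:
            best[key] = chunk

    # Re-rank the merged pool and keep the top final_k chunks for the prompt.
    return sorted(best.values(), key=lambda c: c["score"], reverse=True)[:final_k]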
@@ -27,22 +28,7 @@
 from ragitect.agents.rag.state import RAGState
 from ragitect.api.schemas.chat import Citation
 from ragitect.prompts.rag_prompts import build_rag_system_prompt
-from ragitect.services.adaptive_k import select_adaptive_k
-from ragitect.services.config import (
-    DEFAULT_RETRIEVAL_K,
-    DEFAULT_SIMILARITY_THRESHOLD,
-    RETRIEVAL_ADAPTIVE_K_GAP_THRESHOLD,
-    RETRIEVAL_ADAPTIVE_K_MAX,
-    RETRIEVAL_ADAPTIVE_K_MIN,
-    RETRIEVAL_INITIAL_K,
-    RETRIEVAL_MMR_K,
-    RETRIEVAL_MMR_LAMBDA,
-    RETRIEVAL_RERANKER_TOP_K,
-    RETRIEVAL_USE_ADAPTIVE_K,
-    RETRIEVAL_USE_MMR,
-    RETRIEVAL_USE_RERANKER,
-    EmbeddingConfig,
-)
+from ragitect.services.config import EmbeddingConfig
 from ragitect.services.database.connection import get_async_session
 from ragitect.services.database.repositories.document_repo import DocumentRepository
 from ragitect.services.database.repositories.vector_repo import VectorRepository
@@ -51,14 +37,7 @@
 from ragitect.services.llm import generate_response_stream
 from ragitect.services.llm_config_service import get_active_embedding_config
 from ragitect.services.llm_factory import create_llm_with_provider
-from ragitect.services.mmr import mmr_select
-from ragitect.services.query_service import query_with_iterative_fallback
-from ragitect.services.reranker import rerank_chunks

-# Feature flag for LangGraph-based retrieval
-USE_LANGGRAPH_RETRIEVAL = (
-    os.environ.get("USE_LANGGRAPH_RETRIEVAL", "false").lower() == "true"
-)

 # Compile graph once at module level (performance optimization)
 # Graph compilation is expensive - do it once, reuse across requests
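The retained comments describe the design choice behind this hunk: graph compilation is expensive, so the compiled graph is built once at import time and reused across requests. A minimal sketch of that pattern with LangGraph (the state schema and node below are placeholders, not the actual graph from ragitect.agents.rag):

from typing import TypedDict

from langgraph.graph import END, StateGraph


class SketchState(TypedDict):
    """Placeholder state; the real RAGState lives in ragitect.agents.rag.state."""
    query: str
    context: list[str]


def retrieve_node(state: SketchState) -> dict:
    # Placeholder node; the real graph runs strategy/search/merge nodes.
    return {"context": [f"chunks for: {state['query']}"]}


_builder = StateGraph(SketchState)
_builder.add_node("retrieve", retrieve_node)
_builder.set_entry_point("retrieve")
_builder.add_edge("retrieve", END)

# Compiled once at import time; request handlers reuse this object.
RAG_GRAPH = _builder.compile()


async def handle_request(query: str) -> list[str]:
    result = await RAG_GRAPH.ainvoke({"query": query, "context": []})
    return result["context"]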
@@ -324,193 +303,6 @@ async def empty_workspace_response() -> AsyncGenerator[str, None]:
     yield f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n"


-async def retrieve_context(
-    session: AsyncSession,
-    workspace_id: UUID,
-    query: str,
-    chat_history: list[dict[str, str]],
-    provider: str | None = None,
-    initial_k: int = RETRIEVAL_INITIAL_K,
-    similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
-    use_reranker: bool = RETRIEVAL_USE_RERANKER,
-    use_mmr: bool = RETRIEVAL_USE_MMR,
-    use_adaptive_k: bool = RETRIEVAL_USE_ADAPTIVE_K,
-    mmr_lambda: float = RETRIEVAL_MMR_LAMBDA,
-) -> list[dict]:
-    """Retrieve relevant context chunks using multi-stage retrieval pipeline.
-
-    Pipeline stages:
-    1. Over-retrieve: Get top-50 candidates (AC1)
-    2. Rerank: Use cross-encoder for accurate relevance scoring (AC2)
-    3. MMR: Apply diversity selection to reduce redundancy (AC3)
-    4. Adaptive-K: Select K based on score distribution gaps (AC4)
-
-    Uses query_with_iterative_fallback for intelligent query processing:
-    - Classifies query complexity (simple/ambiguous/complex)
-    - For simple queries: tries direct search, falls back to reformulation if low relevance
-    - For ambiguous/complex: reformulates directly with chat history context
-
-    Args:
-        session: Database session
-        workspace_id: Workspace to search
-        query: User query
-        chat_history: Previous conversation for context
-        provider: Optional provider override for query processing LLM
-        initial_k: Number of candidates for over-retrieval (default 50)
-        similarity_threshold: Minimum similarity for initial retrieval
-        use_reranker: Whether to apply cross-encoder reranking
-        use_mmr: Whether to apply MMR diversity selection
-        use_adaptive_k: Whether to use adaptive K selection
-        mmr_lambda: Balance between relevance and diversity (0-1)
-
-    Returns:
-        List of chunks with content and metadata
-    """
-    # Get LLM for query optimization (uses provider override if specified)
-    llm = await create_llm_with_provider(session, provider=provider)
-
-    # Get embedding configuration and create model
-    embedding_config = await get_active_embedding_config(session)
-
-    # Build EmbeddingConfig from database config
-    if embedding_config:
-        config = EmbeddingConfig(
-            provider=embedding_config.provider_name,
-            model=embedding_config.model_name or "nomic-embed-text",
-            base_url=embedding_config.config_data.get("base_url"),
-            api_key=embedding_config.config_data.get("api_key"),
-            dimension=embedding_config.config_data.get("dimension", 768),
-        )
-    else:
-        config = EmbeddingConfig()  # Use defaults (Ollama)
-
-    embedding_model = create_embeddings_model(config)
-
-    # Store search results and embeddings for pipeline stages
-    search_results_cache: dict[str, list[tuple]] = {}
-    query_embedding_cache: dict[str, list[float]] = {}
-
-    # Create vector search function for iterative fallback
-    async def vector_search_fn(search_query: str) -> list[str]:
-        """Perform vector search and return chunk contents (caches full results)."""
-        query_embedding = await embed_text(embedding_model, search_query)
-        query_embedding_cache[search_query] = query_embedding
-        repo = VectorRepository(session)
-        # Stage 1: Over-retrieve (AC1) - get more candidates for reranking
-        chunks_with_scores = await repo.search_similar_chunks(
-            workspace_id,
-            query_embedding,
-            k=initial_k,
-            similarity_threshold=similarity_threshold,
-        )
-        # Cache full results for later use
-        search_results_cache[search_query] = chunks_with_scores
-        return [chunk.content for chunk, _distance in chunks_with_scores]
-
-    # Use iterative fallback for intelligent query processing and retrieval
-    retrieved_contents, metadata = await query_with_iterative_fallback(
-        llm, query, chat_history, vector_search_fn
-    )
-
-    final_query = metadata.get("final_query", query)
-    logger.info(
-        "Query processed: '%s' -> '%s' (classification=%s, used_reformulation=%s)",
-        query,
-        final_query,
-        metadata.get("classification"),
-        metadata.get("used_reformulation"),
-    )
-
-    # Use cached search results to avoid duplicate retrieval
-    chunks_with_scores = search_results_cache.get(final_query, [])
-    query_embedding = query_embedding_cache.get(final_query, [])
-
-    # Log initial retrieval stats (AC6)
-    if chunks_with_scores:
-        similarities = [1.0 - dist for _, dist in chunks_with_scores]
-        logger.info(
-            "Initial retrieval: %d chunks, similarity range [%.3f, %.3f], mean: %.3f",
-            len(chunks_with_scores),
-            min(similarities),
-            max(similarities),
-            sum(similarities) / len(similarities),
-        )
-
-    # Format chunks for processing pipeline
-    doc_repo = DocumentRepository(session)
-    chunks = []
-    for chunk, distance in chunks_with_scores:
-        # Load the parent document to get filename
-        document = await doc_repo.get_by_id(chunk.document_id)
-
-        chunk_dict = {
-            "content": chunk.content,
-            "document_name": document.file_name if document else "Unknown",
-            "document_id": str(chunk.document_id),
-            "chunk_index": chunk.chunk_index,
-            "similarity": 1.0 - distance,  # Convert distance to similarity
-            "embedding": list(chunk.embedding) if chunk.embedding is not None else [],
-        }
-        chunks.append(chunk_dict)
-
-    # Stage 2: Rerank with cross-encoder (AC2)
-    if use_reranker and chunks:
-        rerank_start = time.time()
-        chunks = await rerank_chunks(
-            final_query, chunks, top_k=RETRIEVAL_RERANKER_TOP_K
-        )
-        rerank_latency = (time.time() - rerank_start) * 1000
-        logger.info(
-            "Reranker latency: %.1fms for %d chunks", rerank_latency, len(chunks)
-        )
-
-    # Stage 3: MMR diversity selection (AC3)
-    if use_mmr and chunks and query_embedding:
-        chunk_embeddings = [c.get("embedding", []) for c in chunks]
-        # Filter out chunks without embeddings
-        valid_chunks = [(c, e) for c, e in zip(chunks, chunk_embeddings) if len(e) > 0]
-        if valid_chunks:
-            valid_chunk_list = [c for c, _ in valid_chunks]
-            valid_embeddings = [e for _, e in valid_chunks]
-            chunks = mmr_select(
-                query_embedding=query_embedding,
-                chunk_embeddings=valid_embeddings,
-                chunks=valid_chunk_list,
-                k=RETRIEVAL_MMR_K,
-                lambda_param=mmr_lambda,
-            )
-            logger.info(
-                "MMR selected %d diverse chunks (lambda=%.2f)", len(chunks), mmr_lambda
-            )
-
-    # Stage 4: Adaptive-K selection (AC4)
-    if use_adaptive_k and chunks:
-        chunks, k_metadata = select_adaptive_k(
-            chunks,
-            score_key="rerank_score" if use_reranker else "similarity",
-            k_min=RETRIEVAL_ADAPTIVE_K_MIN,
-            k_max=RETRIEVAL_ADAPTIVE_K_MAX,
-            gap_threshold=RETRIEVAL_ADAPTIVE_K_GAP_THRESHOLD,
-        )
-        logger.info(
-            "Adaptive-K: selected %d chunks (gap_found=%s)",
-            k_metadata["adaptive_k"],
-            k_metadata["gap_found"],
-        )
-    elif not use_adaptive_k:
-        chunks = chunks[:DEFAULT_RETRIEVAL_K]  # Fallback to fixed K
-
-    # Clean up: remove embedding from final results (not needed for prompt)
-    results = []
-    for i, chunk in enumerate(chunks):
-        chunk_copy = {k: v for k, v in chunk.items() if k != "embedding"}
-        chunk_copy["chunk_label"] = f"Chunk {i + 1}"  # 1-based for citation binding
-        results.append(chunk_copy)
-
-    logger.info("Retrieved %d context chunks after full pipeline", len(results))
-    return results
-
-
 async def retrieve_context_with_graph(
     session: AsyncSession,
     workspace_id: UUID,
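For reference, the removed Stage 4 chose K from the shape of the score distribution: sort the scores, cut at the first large gap between neighbors, and clamp to [k_min, k_max]. A minimal sketch of that idea (the exact cut rule is an assumption, not the deleted select_adaptive_k implementation from ragitect.services.adaptive_k):

def adaptive_k_by_gap(
    scores: list[float],
    k_min: int = 4,
    k_max: int = 16,
    gap_threshold: float = 0.15,
) -> int:
    """Pick K at the first large drop in the sorted score curve, clamped to [k_min, k_max]."""
    ranked = sorted(scores, reverse=True)
    cut = len(ranked)
    # The first gap wider than the threshold marks the end of the "relevant" head.
    for i in range(1, len(ranked)):
        if ranked[i - 1] - ranked[i] >= gap_threshold:
            cut = i
            break
    return max(k_min, min(cut, k_max, len(ranked)))


# Example: the drop from 0.88 to 0.52 marks a cut at 3, clamped up to k_min=4.
print(adaptive_k_by_gap([0.92, 0.90, 0.88, 0.52, 0.50, 0.49]))  # -> 4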
@@ -736,26 +528,15 @@ async def chat_stream(
             },
         )

-        # Retrieve context from documents (AC2)
-        # Use LangGraph-based pipeline if enabled, otherwise legacy manual orchestration
-        if USE_LANGGRAPH_RETRIEVAL:
-            logger.info("Using LangGraph-based retrieval pipeline")
-            context_chunks = await retrieve_context_with_graph(
-                session,
-                workspace_id,
-                request.message,
-                request.chat_history,
-                provider=request.provider,
-            )
-        else:
-            # Legacy: Pass provider override to use consistent LLM for query processing
-            context_chunks = await retrieve_context(
-                session,
-                workspace_id,
-                request.message,
-                request.chat_history,
-                provider=request.provider,
-            )
+        # Retrieve context from documents using LangGraph-based pipeline
+        logger.info("Using LangGraph-based retrieval pipeline")
+        context_chunks = await retrieve_context_with_graph(
+            session,
+            workspace_id,
+            request.message,
+            request.chat_history,
+            provider=request.provider,
+        )

         # Build citation metadata from context chunks
         citations = build_citation_metadata(context_chunks)
