|
1 | 1 | """Chat streaming endpoint using Server-Sent Events (SSE) with RAG integration. |
2 | 2 |
|
3 | 3 | This module provides the SSE streaming endpoint for chat functionality with |
4 | | -full Retrieval-Augmented Generation (RAG) integration. |
| 4 | +full Retrieval-Augmented Generation (RAG) integration using a LangGraph agent-based pipeline.
5 | 5 |
|
6 | | -Supports two retrieval modes: |
7 | | -1. Legacy: Manual orchestration via retrieve_context() |
8 | | -2. LangGraph: Agent-based pipeline via retrieve_context_with_graph() |
| 6 | +The RAG pipeline uses intelligent query decomposition with parallel search execution: |
| 7 | +- Strategy generation: An LLM decomposes the query into 1-5 targeted search terms
| 8 | +- Parallel search: Each term is searched independently with reranking, MMR, and adaptive-K
| 9 | +- Context merging: Aggregated results are deduplicated and re-ranked into the final context
9 | 10 | """ |
10 | 11 |
|
11 | 12 | import json |
|
27 | 28 | from ragitect.agents.rag.state import RAGState |
28 | 29 | from ragitect.api.schemas.chat import Citation |
29 | 30 | from ragitect.prompts.rag_prompts import build_rag_system_prompt |
30 | | -from ragitect.services.adaptive_k import select_adaptive_k |
31 | | -from ragitect.services.config import ( |
32 | | - DEFAULT_RETRIEVAL_K, |
33 | | - DEFAULT_SIMILARITY_THRESHOLD, |
34 | | - RETRIEVAL_ADAPTIVE_K_GAP_THRESHOLD, |
35 | | - RETRIEVAL_ADAPTIVE_K_MAX, |
36 | | - RETRIEVAL_ADAPTIVE_K_MIN, |
37 | | - RETRIEVAL_INITIAL_K, |
38 | | - RETRIEVAL_MMR_K, |
39 | | - RETRIEVAL_MMR_LAMBDA, |
40 | | - RETRIEVAL_RERANKER_TOP_K, |
41 | | - RETRIEVAL_USE_ADAPTIVE_K, |
42 | | - RETRIEVAL_USE_MMR, |
43 | | - RETRIEVAL_USE_RERANKER, |
44 | | - EmbeddingConfig, |
45 | | -) |
| 31 | +from ragitect.services.config import EmbeddingConfig |
46 | 32 | from ragitect.services.database.connection import get_async_session |
47 | 33 | from ragitect.services.database.repositories.document_repo import DocumentRepository |
48 | 34 | from ragitect.services.database.repositories.vector_repo import VectorRepository |
|
51 | 37 | from ragitect.services.llm import generate_response_stream |
52 | 38 | from ragitect.services.llm_config_service import get_active_embedding_config |
53 | 39 | from ragitect.services.llm_factory import create_llm_with_provider |
54 | | -from ragitect.services.mmr import mmr_select |
55 | | -from ragitect.services.query_service import query_with_iterative_fallback |
56 | | -from ragitect.services.reranker import rerank_chunks |
57 | 40 |
|
58 | | -# Feature flag for LangGraph-based retrieval |
59 | | -USE_LANGGRAPH_RETRIEVAL = ( |
60 | | - os.environ.get("USE_LANGGRAPH_RETRIEVAL", "false").lower() == "true" |
61 | | -) |
62 | 41 |
|
63 | 42 | # Compile graph once at module level (performance optimization) |
64 | 43 | # Graph compilation is expensive - do it once, reuse across requests |
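
Note on the compile-once pattern above: the graph is built and compiled at import time so every request reuses one compiled object instead of paying compilation cost per request. A minimal sketch of the idea (the builder name create_rag_graph and its module path are assumptions, since the graph definition is not part of this diff; StateGraph.compile() and ainvoke() are standard LangGraph calls):

    # Hypothetical sketch (not the actual module code): compile the
    # LangGraph StateGraph once at import time and reuse it per request.
    from ragitect.agents.rag.graph import create_rag_graph  # assumed path

    _graph = create_rag_graph()            # assumed to return a langgraph StateGraph
    compiled_graph = _graph.compile()      # compiled once at module load

    # Request handlers then reuse it: await compiled_graph.ainvoke(initial_state)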
@@ -324,193 +303,6 @@ async def empty_workspace_response() -> AsyncGenerator[str, None]: |
324 | 303 | yield f"data: {json.dumps({'type': 'finish', 'finishReason': 'stop'})}\n\n" |
325 | 304 |
|
326 | 305 |
|
327 | | -async def retrieve_context( |
328 | | - session: AsyncSession, |
329 | | - workspace_id: UUID, |
330 | | - query: str, |
331 | | - chat_history: list[dict[str, str]], |
332 | | - provider: str | None = None, |
333 | | - initial_k: int = RETRIEVAL_INITIAL_K, |
334 | | - similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, |
335 | | - use_reranker: bool = RETRIEVAL_USE_RERANKER, |
336 | | - use_mmr: bool = RETRIEVAL_USE_MMR, |
337 | | - use_adaptive_k: bool = RETRIEVAL_USE_ADAPTIVE_K, |
338 | | - mmr_lambda: float = RETRIEVAL_MMR_LAMBDA, |
339 | | -) -> list[dict]: |
340 | | - """Retrieve relevant context chunks using multi-stage retrieval pipeline. |
341 | | -
|
342 | | - Pipeline stages: |
343 | | - 1. Over-retrieve: Get top-50 candidates (AC1) |
344 | | - 2. Rerank: Use cross-encoder for accurate relevance scoring (AC2) |
345 | | - 3. MMR: Apply diversity selection to reduce redundancy (AC3) |
346 | | - 4. Adaptive-K: Select K based on score distribution gaps (AC4) |
347 | | -
|
348 | | - Uses query_with_iterative_fallback for intelligent query processing: |
349 | | - - Classifies query complexity (simple/ambiguous/complex) |
350 | | - - For simple queries: tries direct search, falls back to reformulation if low relevance |
351 | | - - For ambiguous/complex: reformulates directly with chat history context |
352 | | -
|
353 | | - Args: |
354 | | - session: Database session |
355 | | - workspace_id: Workspace to search |
356 | | - query: User query |
357 | | - chat_history: Previous conversation for context |
358 | | - provider: Optional provider override for query processing LLM |
359 | | - initial_k: Number of candidates for over-retrieval (default 50) |
360 | | - similarity_threshold: Minimum similarity for initial retrieval |
361 | | - use_reranker: Whether to apply cross-encoder reranking |
362 | | - use_mmr: Whether to apply MMR diversity selection |
363 | | - use_adaptive_k: Whether to use adaptive K selection |
364 | | - mmr_lambda: Balance between relevance and diversity (0-1) |
365 | | -
|
366 | | - Returns: |
367 | | - List of chunks with content and metadata |
368 | | - """ |
369 | | - # Get LLM for query optimization (uses provider override if specified) |
370 | | - llm = await create_llm_with_provider(session, provider=provider) |
371 | | - |
372 | | - # Get embedding configuration and create model |
373 | | - embedding_config = await get_active_embedding_config(session) |
374 | | - |
375 | | - # Build EmbeddingConfig from database config |
376 | | - if embedding_config: |
377 | | - config = EmbeddingConfig( |
378 | | - provider=embedding_config.provider_name, |
379 | | - model=embedding_config.model_name or "nomic-embed-text", |
380 | | - base_url=embedding_config.config_data.get("base_url"), |
381 | | - api_key=embedding_config.config_data.get("api_key"), |
382 | | - dimension=embedding_config.config_data.get("dimension", 768), |
383 | | - ) |
384 | | - else: |
385 | | - config = EmbeddingConfig() # Use defaults (Ollama) |
386 | | - |
387 | | - embedding_model = create_embeddings_model(config) |
388 | | - |
389 | | - # Store search results and embeddings for pipeline stages |
390 | | - search_results_cache: dict[str, list[tuple]] = {} |
391 | | - query_embedding_cache: dict[str, list[float]] = {} |
392 | | - |
393 | | - # Create vector search function for iterative fallback |
394 | | - async def vector_search_fn(search_query: str) -> list[str]: |
395 | | - """Perform vector search and return chunk contents (caches full results).""" |
396 | | - query_embedding = await embed_text(embedding_model, search_query) |
397 | | - query_embedding_cache[search_query] = query_embedding |
398 | | - repo = VectorRepository(session) |
399 | | - # Stage 1: Over-retrieve (AC1) - get more candidates for reranking |
400 | | - chunks_with_scores = await repo.search_similar_chunks( |
401 | | - workspace_id, |
402 | | - query_embedding, |
403 | | - k=initial_k, |
404 | | - similarity_threshold=similarity_threshold, |
405 | | - ) |
406 | | - # Cache full results for later use |
407 | | - search_results_cache[search_query] = chunks_with_scores |
408 | | - return [chunk.content for chunk, _distance in chunks_with_scores] |
409 | | - |
410 | | - # Use iterative fallback for intelligent query processing and retrieval |
411 | | - retrieved_contents, metadata = await query_with_iterative_fallback( |
412 | | - llm, query, chat_history, vector_search_fn |
413 | | - ) |
414 | | - |
415 | | - final_query = metadata.get("final_query", query) |
416 | | - logger.info( |
417 | | - "Query processed: '%s' -> '%s' (classification=%s, used_reformulation=%s)", |
418 | | - query, |
419 | | - final_query, |
420 | | - metadata.get("classification"), |
421 | | - metadata.get("used_reformulation"), |
422 | | - ) |
423 | | - |
424 | | - # Use cached search results to avoid duplicate retrieval |
425 | | - chunks_with_scores = search_results_cache.get(final_query, []) |
426 | | - query_embedding = query_embedding_cache.get(final_query, []) |
427 | | - |
428 | | - # Log initial retrieval stats (AC6) |
429 | | - if chunks_with_scores: |
430 | | - similarities = [1.0 - dist for _, dist in chunks_with_scores] |
431 | | - logger.info( |
432 | | - "Initial retrieval: %d chunks, similarity range [%.3f, %.3f], mean: %.3f", |
433 | | - len(chunks_with_scores), |
434 | | - min(similarities), |
435 | | - max(similarities), |
436 | | - sum(similarities) / len(similarities), |
437 | | - ) |
438 | | - |
439 | | - # Format chunks for processing pipeline |
440 | | - doc_repo = DocumentRepository(session) |
441 | | - chunks = [] |
442 | | - for chunk, distance in chunks_with_scores: |
443 | | - # Load the parent document to get filename |
444 | | - document = await doc_repo.get_by_id(chunk.document_id) |
445 | | - |
446 | | - chunk_dict = { |
447 | | - "content": chunk.content, |
448 | | - "document_name": document.file_name if document else "Unknown", |
449 | | - "document_id": str(chunk.document_id), |
450 | | - "chunk_index": chunk.chunk_index, |
451 | | - "similarity": 1.0 - distance, # Convert distance to similarity |
452 | | - "embedding": list(chunk.embedding) if chunk.embedding is not None else [], |
453 | | - } |
454 | | - chunks.append(chunk_dict) |
455 | | - |
456 | | - # Stage 2: Rerank with cross-encoder (AC2) |
457 | | - if use_reranker and chunks: |
458 | | - rerank_start = time.time() |
459 | | - chunks = await rerank_chunks( |
460 | | - final_query, chunks, top_k=RETRIEVAL_RERANKER_TOP_K |
461 | | - ) |
462 | | - rerank_latency = (time.time() - rerank_start) * 1000 |
463 | | - logger.info( |
464 | | - "Reranker latency: %.1fms for %d chunks", rerank_latency, len(chunks) |
465 | | - ) |
466 | | - |
467 | | - # Stage 3: MMR diversity selection (AC3) |
468 | | - if use_mmr and chunks and query_embedding: |
469 | | - chunk_embeddings = [c.get("embedding", []) for c in chunks] |
470 | | - # Filter out chunks without embeddings |
471 | | - valid_chunks = [(c, e) for c, e in zip(chunks, chunk_embeddings) if len(e) > 0] |
472 | | - if valid_chunks: |
473 | | - valid_chunk_list = [c for c, _ in valid_chunks] |
474 | | - valid_embeddings = [e for _, e in valid_chunks] |
475 | | - chunks = mmr_select( |
476 | | - query_embedding=query_embedding, |
477 | | - chunk_embeddings=valid_embeddings, |
478 | | - chunks=valid_chunk_list, |
479 | | - k=RETRIEVAL_MMR_K, |
480 | | - lambda_param=mmr_lambda, |
481 | | - ) |
482 | | - logger.info( |
483 | | - "MMR selected %d diverse chunks (lambda=%.2f)", len(chunks), mmr_lambda |
484 | | - ) |
485 | | - |
486 | | - # Stage 4: Adaptive-K selection (AC4) |
487 | | - if use_adaptive_k and chunks: |
488 | | - chunks, k_metadata = select_adaptive_k( |
489 | | - chunks, |
490 | | - score_key="rerank_score" if use_reranker else "similarity", |
491 | | - k_min=RETRIEVAL_ADAPTIVE_K_MIN, |
492 | | - k_max=RETRIEVAL_ADAPTIVE_K_MAX, |
493 | | - gap_threshold=RETRIEVAL_ADAPTIVE_K_GAP_THRESHOLD, |
494 | | - ) |
495 | | - logger.info( |
496 | | - "Adaptive-K: selected %d chunks (gap_found=%s)", |
497 | | - k_metadata["adaptive_k"], |
498 | | - k_metadata["gap_found"], |
499 | | - ) |
500 | | - elif not use_adaptive_k: |
501 | | - chunks = chunks[:DEFAULT_RETRIEVAL_K] # Fallback to fixed K |
502 | | - |
503 | | - # Clean up: remove embedding from final results (not needed for prompt) |
504 | | - results = [] |
505 | | - for i, chunk in enumerate(chunks): |
506 | | - chunk_copy = {k: v for k, v in chunk.items() if k != "embedding"} |
507 | | - chunk_copy["chunk_label"] = f"Chunk {i + 1}" # 1-based for citation binding |
508 | | - results.append(chunk_copy) |
509 | | - |
510 | | - logger.info("Retrieved %d context chunks after full pipeline", len(results)) |
511 | | - return results |
512 | | - |
513 | | - |
514 | 306 | async def retrieve_context_with_graph( |
515 | 307 | session: AsyncSession, |
516 | 308 | workspace_id: UUID, |
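
Note: the body of retrieve_context_with_graph is unchanged by this commit, so it is elided above. Per the new module docstring, the graph fans each generated search term out to an independent search and merges the results with deduplication. A minimal sketch of that fan-out/merge step (all names hypothetical; shown with plain asyncio rather than the actual LangGraph nodes, and using the document_id/chunk_index/similarity keys visible in the removed legacy code):

    import asyncio

    async def fan_out_and_merge(terms: list[str], search_fn) -> list[dict]:
        # One vector search per decomposed term, executed concurrently.
        per_term = await asyncio.gather(*(search_fn(t) for t in terms))
        # Deduplicate by (document_id, chunk_index), keeping the best-scoring copy.
        best: dict[tuple, dict] = {}
        for chunks in per_term:
            for chunk in chunks:
                key = (chunk["document_id"], chunk["chunk_index"])
                if key not in best or chunk["similarity"] > best[key]["similarity"]:
                    best[key] = chunk
        # Re-rank the merged pool by score to form the final context.
        return sorted(best.values(), key=lambda c: c["similarity"], reverse=True)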
@@ -736,26 +528,15 @@ async def chat_stream( |
736 | 528 | }, |
737 | 529 | ) |
738 | 530 |
|
739 | | - # Retrieve context from documents (AC2) |
740 | | - # Use LangGraph-based pipeline if enabled, otherwise legacy manual orchestration |
741 | | - if USE_LANGGRAPH_RETRIEVAL: |
742 | | - logger.info("Using LangGraph-based retrieval pipeline") |
743 | | - context_chunks = await retrieve_context_with_graph( |
744 | | - session, |
745 | | - workspace_id, |
746 | | - request.message, |
747 | | - request.chat_history, |
748 | | - provider=request.provider, |
749 | | - ) |
750 | | - else: |
751 | | - # Legacy: Pass provider override to use consistent LLM for query processing |
752 | | - context_chunks = await retrieve_context( |
753 | | - session, |
754 | | - workspace_id, |
755 | | - request.message, |
756 | | - request.chat_history, |
757 | | - provider=request.provider, |
758 | | - ) |
| 531 | + # Retrieve context from documents using the LangGraph-based pipeline
| 532 | + logger.info("Using LangGraph-based retrieval pipeline") |
| 533 | + context_chunks = await retrieve_context_with_graph( |
| 534 | + session, |
| 535 | + workspace_id, |
| 536 | + request.message, |
| 537 | + request.chat_history, |
| 538 | + provider=request.provider, |
| 539 | + ) |
759 | 540 |
|
760 | 541 | # Build citation metadata from context chunks |
761 | 542 | citations = build_citation_metadata(context_chunks) |
|