diff --git a/README.md b/README.md
index bc0affd..fbe2a1b 100644
--- a/README.md
+++ b/README.md
@@ -159,9 +159,35 @@ insert_document(Path("Special Relativity.pdf"), config=my_config)
### 3. Searching and Retrieval-Augmented Generation (RAG)
-#### 3.1 Simple RAG pipeline
+#### 3.1 Dynamically routed RAG
-Now you can run a simple but powerful RAG pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:
+Now you can run a dynamically routed RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking a retrieval tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans (each of which is a list of consecutive chunks). The retrieval results are sent to the `on_retrieval` callback and are also appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the assistant response:
+
+```python
+from raglite import rag
+
+# Create a user message:
+messages = [] # Or start with an existing message history.
+messages.append({
+    "role": "user",
+    "content": "How is intelligence measured?"
+})
+
+# Let the LLM decide whether to search the database by providing a retrieval tool to the LLM.
+# If requested, RAGLite then uses hybrid search and reranking to append RAG context to the message history.
+# Finally, the assistant response is streamed and appended to the message history.
+chunk_spans = []
+stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=my_config)
+for update in stream:
+ print(update, end="")
+
+# Access the documents referenced in the RAG context:
+documents = [chunk_span.document for chunk_span in chunk_spans]
+```
+
+#### 3.2 Programmable RAG
+
+If you need manual control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:
```python
from raglite import create_rag_instruction, rag, retrieve_rag_context
@@ -174,21 +200,19 @@ chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=my_co
messages = [] # Or start with an existing message history.
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
-# Stream the RAG response:
+# Stream the RAG response and append it to the message history:
stream = rag(messages, config=my_config)
for update in stream:
print(update, end="")
-# Access the documents cited in the RAG response:
+# Access the documents referenced in the RAG context:
documents = [chunk_span.document for chunk_span in chunk_spans]
```
-#### 3.2 Advanced RAG pipeline
-
> [!TIP]
> 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks.
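
As a concrete sketch of that recipe (assuming RAGLite's `hybrid_search`, `retrieve_chunks`, `rerank_chunks`, and `retrieve_chunk_spans` helpers, with keyword arguments mirroring their use in `src/raglite/_chainlit.py`):

```python
from raglite import hybrid_search, rerank_chunks, retrieve_chunk_spans, retrieve_chunks

# Search for a larger pool of 20 candidate chunks with hybrid search:
chunk_ids, _ = hybrid_search(query=user_prompt, num_results=20, config=my_config)
# Retrieve the chunks from the database:
chunks = retrieve_chunks(chunk_ids=chunk_ids, config=my_config)
# Rerank the chunks and keep the top 5:
chunks = rerank_chunks(query=user_prompt, chunk_ids=chunks, config=my_config)
# Group the top chunks into chunk spans to use as RAG context:
chunk_spans = retrieve_chunk_spans(chunks[:5], config=my_config)
```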
-In addition to the simple RAG pipeline, RAGLite also offers more advanced control over the individual steps of the pipeline. A full pipeline consists of several steps:
+RAGLite also offers more advanced control over the individual steps of a full RAG pipeline:
1. Searching for relevant chunks with keyword, vector, or hybrid search
2. Retrieving the chunks from the database
@@ -229,14 +253,14 @@ from raglite import create_rag_instruction
messages = [] # Or start with an existing message history.
messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
-# Stream the RAG response:
+# Stream the RAG response and append it to the message history:
from raglite import rag
stream = rag(messages, config=my_config)
for update in stream:
print(update, end="")
-# Access the documents cited in the RAG response:
+# Access the documents referenced in the RAG context:
documents = [chunk_span.document for chunk_span in chunk_spans]
```
diff --git a/poetry.lock b/poetry.lock
index 5626d81..7353b43 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2547,13 +2547,13 @@ test = ["coverage", "pytest", "pytest-cov"]
[[package]]
name = "litellm"
-version = "1.47.1"
+version = "1.48.10"
description = "Library to easily interface with LLM API providers"
optional = false
python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8"
files = [
- {file = "litellm-1.47.1-py3-none-any.whl", hash = "sha256:baa1961287ee398c937e8a5ecd1fcb821ea8f91cbd1f4757b6c19d7bcc84d4fd"},
- {file = "litellm-1.47.1.tar.gz", hash = "sha256:51d1eb353573ddeac75c45b66147f533f64f231540667ea30b63edb9a2af15ce"},
+ {file = "litellm-1.48.10-py3-none-any.whl", hash = "sha256:752efd59747a0895f4695d025c66f0b2258d80a61175f7cfa41dbe4894ef95e1"},
+ {file = "litellm-1.48.10.tar.gz", hash = "sha256:0a4ff75da78e66baeae0658ad8de498298310a5efda74c3d840ce2b013e8401d"},
]
[package.dependencies]
@@ -6813,4 +6813,4 @@ ragas = ["ragas"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<4.0"
-content-hash = "ff8a8596ac88ae5406234f08810e2e4d654714af0aa3663451e988a2cf6ef51e"
+content-hash = "b3a14066711fe4caec356d0aa18514495d44ac253371d2560fc0c5aea890aaef"
diff --git a/pyproject.toml b/pyproject.toml
index 3cecd49..10e3f57 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -32,7 +32,7 @@ scipy = ">=1.5.0"
spacy = ">=3.7.0,<3.8.0"
# Large Language Models:
huggingface-hub = ">=0.22.0"
-litellm = ">=1.47.1"
+litellm = ">=1.48.4"
llama-cpp-python = ">=0.3.2"
pydantic = ">=2.7.0"
# Approximate Nearest Neighbors:
diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py
index 1f3eeeb..7dd53ef 100644
--- a/src/raglite/_chainlit.py
+++ b/src/raglite/_chainlit.py
@@ -6,22 +6,11 @@
import chainlit as cl
from chainlit.input_widget import Switch, TextInput
-from raglite import (
- RAGLiteConfig,
- async_rag,
- create_rag_instruction,
- hybrid_search,
- insert_document,
- rerank_chunks,
- retrieve_chunk_spans,
- retrieve_chunks,
-)
+from raglite import RAGLiteConfig, async_rag, hybrid_search, insert_document, rerank_chunks
from raglite._markdown import document_to_markdown
async_insert_document = cl.make_async(insert_document)
async_hybrid_search = cl.make_async(hybrid_search)
-async_retrieve_chunks = cl.make_async(retrieve_chunks)
-async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans)
async_rerank_chunks = cl.make_async(rerank_chunks)
@@ -93,31 +82,27 @@ async def handle_message(user_message: cl.Message) -> None:
for i, attachment in enumerate(inline_attachments)
)
+ f"\n\n{user_message.content}"
- )
- # Search for relevant contexts for RAG.
- async with cl.Step(name="search", type="retrieval") as step:
- step.input = user_message.content
- chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config)
- chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config)
- step.output = chunks
- step.elements = [ # Show the top chunks inline.
- cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5]
- ]
- await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
- # Rerank the chunks and group them into chunk spans.
- async with cl.Step(name="rerank", type="rerank") as step:
- step.input = chunks
- chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config)
- chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config)
- step.output = chunk_spans
- step.elements = [ # Show the top chunk spans inline.
- cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans
- ]
- await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602.
+ ).strip()
# Stream the LLM response.
assistant_message = cl.Message(content="")
+ chunk_spans = []
messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call]
- messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans))
- async for token in async_rag(messages, config=config):
+ messages.append({"role": "user", "content": user_prompt})
+ async for token in async_rag(
+ messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config
+ ):
await assistant_message.stream_token(token)
+ # Append RAG sources, if any.
+ if chunk_spans:
+ rag_sources: dict[str, list[str]] = {}
+ for chunk_span in chunk_spans:
+ rag_sources.setdefault(chunk_span.document.id, [])
+ rag_sources[chunk_span.document.id].append(str(chunk_span))
+ assistant_message.content += "\n\nSources: " + ", ".join( # Rendered as hyperlinks.
+ f"[{i + 1}]" for i in range(len(rag_sources))
+ )
+ assistant_message.elements = [ # Markdown content is rendered in sidebar.
+ cl.Text(name=f"[{i + 1}]", content="\n\n---\n\n".join(content), display="side") # type: ignore[misc]
+ for i, (_, content) in enumerate(rag_sources.items())
+ ]
await assistant_message.update() # type: ignore[no-untyped-call]
diff --git a/src/raglite/_database.py b/src/raglite/_database.py
index 95fc88b..573a3cc 100644
--- a/src/raglite/_database.py
+++ b/src/raglite/_database.py
@@ -169,18 +169,41 @@ def to_xml(self, index: int | None = None) -> str:
if not self.chunks:
return ""
index_attribute = f' index="{index}"' if index is not None else ""
- xml = "\n".join(
+ xml_document = "\n".join(
[
            f'<document{index_attribute} id="{self.document.id}">',
            f"<source>{self.document.url if self.document.url else self.document.filename}</source>",
-            f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[-1].id}">',
-            f"<headings>\n{escape(self.chunks[0].headings.strip())}\n</headings>",
+            f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[-1].id}">',
+            f"<headings>\n{escape(self.chunks[0].headings.strip())}\n</headings>",
            f"<content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</content>",
            "</span>",
            "</document>",
]
)
- return xml
+ return xml_document
+
+ def to_json(self, index: int | None = None) -> str:
+ """Convert this chunk span to a JSON representation.
+
+ The JSON representation follows Anthropic's best practices [1].
+
+ [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
+ """
+ if not self.chunks:
+ return "{}"
+ index_attribute = {"index": index} if index is not None else {}
+ json_document = {
+ **index_attribute,
+ "id": self.document.id,
+ "source": self.document.url if self.document.url else self.document.filename,
+ "span": {
+ "from_chunk_id": self.chunks[0].id,
+ "to_chunk_id": self.chunks[-1].id,
+ "headings": self.chunks[0].headings.strip(),
+ "content": "".join(chunk.body for chunk in self.chunks).strip(),
+ },
+ }
+ return json.dumps(json_document)
@property
def content(self) -> str:
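
For orientation, the JSON emitted by `to_json` is shaped as sketched below (illustrative values only); the `from_chunk_id` key is also what `_get_tools` in `_rag.py` probes for to detect existing RAG context in a message:

```python
import json

# Illustrative only: the keys mirror ChunkSpan.to_json() above; every value here is made up.
example_span = {
    "index": 1,
    "id": "doc-1",
    "source": "Special Relativity.pdf",
    "span": {
        "from_chunk_id": "chunk-7",
        "to_chunk_id": "chunk-9",
        "headings": "# On the Electrodynamics of Moving Bodies",
        "content": "Two events which are simultaneous for one observer need not be simultaneous for another.",
    },
}
print(json.dumps(example_span, indent=2))
```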
diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py
index bd85d47..c902e68 100644
--- a/src/raglite/_extract.py
+++ b/src/raglite/_extract.py
@@ -34,13 +34,12 @@ class MyNameResponse(BaseModel):
# Load the default config if not provided.
config = config or RAGLiteConfig()
# Check if the LLM supports the response format.
- llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
llm_supports_response_format = "response_format" in (
- get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or []
+ get_supported_openai_params(model=config.llm) or []
)
# Update the system prompt with the JSON schema of the return type to help the LLM.
system_prompt = getattr(return_type, "system_prompt", "").strip()
- if not llm_supports_response_format or llm_provider == "llama-cpp-python":
+ if not llm_supports_response_format or config.llm.startswith("llama-cpp-python"):
system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}"
# Constrain the response format to the JSON schema if it's supported by the LLM [1]. Strict mode
# is disabled by default because it only supports a subset of JSON schema features [2].
diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py
index 1135e97..e0bab3b 100644
--- a/src/raglite/_litellm.py
+++ b/src/raglite/_litellm.py
@@ -13,6 +13,8 @@
import httpx
import litellm
from litellm import ( # type: ignore[attr-defined]
+ ChatCompletionToolCallChunk,
+ ChatCompletionToolCallFunctionChunk,
CustomLLM,
GenericStreamingChunk,
ModelResponse,
@@ -20,6 +22,7 @@
get_model_info,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import custom_llm_setup
from llama_cpp import ( # type: ignore[attr-defined]
ChatCompletionRequestMessage,
CreateChatCompletionResponse,
@@ -31,6 +34,7 @@
from raglite._config import RAGLiteConfig
# Reduce the logging level for LiteLLM, flashrank, and httpx.
+litellm.suppress_debug_info = True
os.environ["LITELLM_LOG"] = "WARNING"
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("flashrank").setLevel(logging.WARNING)
@@ -112,6 +116,8 @@ def llm(model: str, **kwargs: Any) -> Llama:
n_ctx=n_ctx,
n_gpu_layers=-1,
verbose=False,
+ # Enable function calling.
+ chat_format="chatml-function-calling",
# Workaround to enable long context embedding models [1].
# [1] https://github.com/abetlen/llama-cpp-python/issues/1762
n_batch=n_ctx if n_ctx > 0 else 1024,
@@ -121,24 +127,23 @@ def llm(model: str, **kwargs: Any) -> Llama:
# Enable caching.
llm.set_cache(LlamaRAMCache())
# Register the model info with LiteLLM.
- litellm.register_model( # type: ignore[attr-defined]
- {
- model: {
- "max_tokens": llm.n_ctx(),
- "max_input_tokens": llm.n_ctx(),
- "max_output_tokens": None,
- "input_cost_per_token": 0.0,
- "output_cost_per_token": 0.0,
- "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
- "litellm_provider": "llama-cpp-python",
- "mode": "embedding" if kwargs.get("embedding") else "completion",
- "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
- "supports_function_calling": True,
- "supports_parallel_function_calling": True,
- "supports_vision": False,
- }
+ model_info = {
+ repo_id_filename: {
+ "max_tokens": llm.n_ctx(),
+ "max_input_tokens": llm.n_ctx(),
+ "max_output_tokens": None,
+ "input_cost_per_token": 0.0,
+ "output_cost_per_token": 0.0,
+ "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None,
+ "litellm_provider": "llama-cpp-python",
+ "mode": "embedding" if kwargs.get("embedding") else "completion",
+ "supported_openai_params": LlamaCppPythonLLM.supported_openai_params,
+ "supports_function_calling": True,
+ "supports_parallel_function_calling": True,
+ "supports_vision": False,
}
- )
+ }
+ litellm.register_model(model_info) # type: ignore[attr-defined]
return llm
def _translate_openai_params(self, optional_params: dict[str, Any]) -> dict[str, Any]:
@@ -218,24 +223,40 @@ def streaming( # noqa: PLR0913
llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
)
for chunk in stream:
- choices = chunk.get("choices", [])
- for choice in choices:
- text = choice.get("delta", {}).get("content", None)
- finish_reason = choice.get("finish_reason")
- litellm_generic_streaming_chunk = GenericStreamingChunk(
- text=text, # type: ignore[typeddict-item]
- is_finished=bool(finish_reason),
- finish_reason=finish_reason, # type: ignore[typeddict-item]
- usage=None,
- index=choice.get("index"), # type: ignore[typeddict-item]
- provider_specific_fields={
- "id": chunk.get("id"),
- "model": chunk.get("model"),
- "created": chunk.get("created"),
- "object": chunk.get("object"),
- },
+ choices = chunk.get("choices")
+ if not choices:
+ continue
+ text = choices[0].get("delta", {}).get("content", None)
+ tool_calls = choices[0].get("delta", {}).get("tool_calls", None)
+ tool_use = (
+ ChatCompletionToolCallChunk(
+ id=tool_calls[0]["id"], # type: ignore[index]
+ type="function",
+ function=ChatCompletionToolCallFunctionChunk(
+ name=tool_calls[0]["function"]["name"], # type: ignore[index]
+ arguments=tool_calls[0]["function"]["arguments"], # type: ignore[index]
+ ),
+ index=tool_calls[0]["index"], # type: ignore[index]
)
- yield litellm_generic_streaming_chunk
+ if tool_calls
+ else None
+ )
+ finish_reason = choices[0].get("finish_reason")
+ litellm_generic_streaming_chunk = GenericStreamingChunk(
+ text=text, # type: ignore[typeddict-item]
+ tool_use=tool_use,
+ is_finished=bool(finish_reason),
+ finish_reason=finish_reason, # type: ignore[typeddict-item]
+ usage=None,
+ index=choices[0].get("index"), # type: ignore[typeddict-item]
+ provider_specific_fields={
+ "id": chunk.get("id"),
+ "model": chunk.get("model"),
+ "created": chunk.get("created"),
+ "object": chunk.get("object"),
+ },
+ )
+ yield litellm_generic_streaming_chunk
async def astreaming( # type: ignore[misc,override] # noqa: PLR0913
self,
@@ -287,7 +308,7 @@ async def astreaming( # type: ignore[misc,override] # noqa: PLR0913
litellm.custom_provider_map.append(
{"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()}
)
- litellm.suppress_debug_info = True
+ custom_llm_setup() # type: ignore[no-untyped-call]
@cache
@@ -298,8 +319,7 @@ def get_context_size(config: RAGLiteConfig, *, fallback: int = 2048) -> int:
if config.llm.startswith("llama-cpp-python"):
_ = LlamaCppPythonLLM.llm(config.llm)
# Attempt to read the context size from LiteLLM's model info.
- llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None
- model_info = get_model_info(config.llm, custom_llm_provider=llm_provider)
+ model_info = get_model_info(config.llm)
max_tokens = model_info.get("max_tokens")
if isinstance(max_tokens, int) and max_tokens > 0:
return max_tokens
@@ -322,8 +342,7 @@ def get_embedding_dim(config: RAGLiteConfig, *, fallback: bool = True) -> int:
if config.embedder.startswith("llama-cpp-python"):
_ = LlamaCppPythonLLM.llm(config.embedder, embedding=True)
# Attempt to read the embedding dimension from LiteLLM's model info.
- llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None
- model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider)
+ model_info = get_model_info(config.embedder)
embedding_dim = model_info.get("output_vector_size")
if isinstance(embedding_dim, int) and embedding_dim > 0:
return embedding_dim
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index 8fb1a0c..da52e23 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -1,9 +1,17 @@
"""Retrieval-augmented generation."""
-from collections.abc import AsyncIterator, Iterator
+import json
+from collections.abc import AsyncIterator, Callable, Iterator
+from typing import Any
import numpy as np
-from litellm import acompletion, completion
+from litellm import ( # type: ignore[attr-defined]
+ ChatCompletionMessageToolCall,
+ acompletion,
+ completion,
+ stream_chunk_builder,
+ supports_function_calling,
+)
from raglite._config import RAGLiteConfig
from raglite._database import ChunkSpan
@@ -70,25 +78,214 @@ def create_rag_instruction(
return message
-def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[str]:
- # Truncate the oldest messages so we don't hit the context limit.
+def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]:
+ """Left clip a messages array to avoid hitting the context limit."""
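+    # A message's token count is approximated as len(content) // 3. For example, with
+    # max_tokens=10 and message contents of lengths [30, 30, 12] (roughly [10, 10, 4] tokens),
+    # the reversed cumulative sums are [4, 14, 24], searchsorted returns 1, and
+    # messages[-1:] keeps only the final message.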
+ cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1])
+ first_message = -np.searchsorted(cum_tokens, max_tokens)
+ return messages[first_message:]
+
+
+def _get_tools(
+ messages: list[dict[str, str]], config: RAGLiteConfig
+) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | str | None]:
+ """Get tools to search the knowledge base if no RAG context is provided in the messages."""
+ # Check if messages already contain RAG context or if the LLM supports tool use.
+ final_message = messages[-1].get("content", "")
+    messages_contain_rag_context = any(s in final_message for s in ("</document>", "from_chunk_id"))
+ llm_supports_function_calling = supports_function_calling(config.llm)
+ if not messages_contain_rag_context and not llm_supports_function_calling:
+ error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling."
+ raise ValueError(error_message)
+ # Add a tool to search the knowledge base if no RAG context is provided in the messages. Because
+ # llama-cpp-python cannot stream tool_use='auto' yet, we use a workaround that forces the LLM
+ # to use a tool, but allows it to skip the search.
+ auto_tool_use_workaround = (
+ {
+ "expert": {
+ "type": "boolean",
+ "description": "The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.",
+ }
+ }
+ if config.llm.startswith("llama-cpp-python")
+ else {}
+ )
+ tools: list[dict[str, Any]] | None = (
+ [
+ {
+ "type": "function",
+ "function": {
+ "name": "search_knowledge_base",
+ "description": "Search the knowledge base. IMPORTANT: Only use this tool if a well-rounded non-expert would need to look up information to answer the question.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ **auto_tool_use_workaround,
+ "query": {
+ "type": "string",
+ "description": (
+ "The `query` string to search the knowledge base with.\n"
+ "The `query` string MUST satisfy ALL of the following criteria:\n"
+ "- The `query` string MUST be a precise question in the user's language.\n"
+ "- The `query` string MUST resolve all pronouns to explicit nouns from the conversation history."
+ ),
+ },
+ },
+ "required": [*list(auto_tool_use_workaround), "query"],
+ "additionalProperties": False,
+ },
+ },
+ }
+ ]
+ if not messages_contain_rag_context
+ else None
+ )
+ tool_choice: dict[str, Any] | str | None = (
+ (
+ {"type": "function", "function": {"name": "search_knowledge_base"}}
+ if auto_tool_use_workaround
+ else "auto"
+ )
+ if tools
+ else None
+ )
+ return tools, tool_choice
+
+
+def _run_tools(
+ tool_calls: list[ChatCompletionMessageToolCall],
+ on_retrieval: Callable[[list[ChunkSpan]], None] | None,
+ config: RAGLiteConfig,
+) -> list[dict[str, Any]]:
+ """Run tools to search the knowledge base for RAG context."""
+ tool_messages: list[dict[str, Any]] = []
+ for tool_call in tool_calls:
+ if tool_call.function.name == "search_knowledge_base":
+ kwargs = json.loads(tool_call.function.arguments)
+ kwargs["config"] = config
+ skip = not kwargs.pop("expert", True)
+ chunk_spans = retrieve_rag_context(**kwargs) if not skip and kwargs["query"] else None
+ tool_messages.append(
+ {
+ "role": "tool",
+ "content": '{{"documents": [{elements}]}}'.format(
+ elements=", ".join(
+ chunk_span.to_json(index=i + 1)
+ for i, chunk_span in enumerate(chunk_spans) # type: ignore[arg-type]
+ )
+ )
+ if not skip and kwargs["query"]
+ else "{}",
+ "tool_call_id": tool_call.id,
+ }
+ )
+ if chunk_spans and callable(on_retrieval):
+ on_retrieval(chunk_spans)
+ else:
+ error_message = f"Unknown function `{tool_call.function.name}`."
+ raise ValueError(error_message)
+ return tool_messages
+
+
+def rag(
+ messages: list[dict[str, str]],
+ *,
+ on_retrieval: Callable[[list[ChunkSpan]], None] | None = None,
+ config: RAGLiteConfig,
+) -> Iterator[str]:
+ # If the final message does not contain RAG context, get a tool to search the knowledge base.
max_tokens = get_context_size(config)
- cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1])
- messages = messages[-np.searchsorted(cum_tokens, max_tokens) :]
- # Stream the LLM response.
- stream = completion(model=config.llm, messages=messages, stream=True)
- for output in stream:
- token: str = output["choices"][0]["delta"].get("content") or ""
- yield token
+ tools, tool_choice = _get_tools(messages, config)
+ # Stream the LLM response, which is either a tool call request or an assistant response.
+ chunks = []
+ clipped_messages = _clip(messages, max_tokens)
+ if tools and config.llm.startswith("llama-cpp-python"):
+ # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
+ clipped_messages[-1]["content"] += (
+            "\n\n<tools>\n"
+ f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
+ "IMPORTANT: The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.\n"
+            "</tools>"
+ )
+ stream = completion(
+ model=config.llm,
+ messages=clipped_messages,
+ tools=tools,
+ tool_choice=tool_choice,
+ stream=True,
+ )
+ for chunk in stream:
+ chunks.append(chunk)
+ if isinstance(token := chunk.choices[0].delta.content, str):
+ yield token
+ # Check if there are tools to be called.
+ response = stream_chunk_builder(chunks, messages)
+ tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr]
+ if tool_calls:
+ # Add the tool call request to the message array.
+ messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
+ # Run the tool calls to retrieve the RAG context and append the output to the message array.
+ messages.extend(_run_tools(tool_calls, on_retrieval, config))
+ # Stream the assistant response.
+ chunks = []
+ stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True)
+ for chunk in stream:
+ chunks.append(chunk)
+ if isinstance(token := chunk.choices[0].delta.content, str):
+ yield token
+ # Append the assistant response to the message array.
+ response = stream_chunk_builder(chunks, messages)
+ messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
-async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> AsyncIterator[str]:
- # Truncate the oldest messages so we don't hit the context limit.
+async def async_rag(
+ messages: list[dict[str, str]],
+ *,
+ on_retrieval: Callable[[list[ChunkSpan]], None] | None = None,
+ config: RAGLiteConfig,
+) -> AsyncIterator[str]:
+ # If the final message does not contain RAG context, get a tool to search the knowledge base.
max_tokens = get_context_size(config)
- cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1])
- messages = messages[-np.searchsorted(cum_tokens, max_tokens) :]
- # Asynchronously stream the LLM response.
- async_stream = await acompletion(model=config.llm, messages=messages, stream=True)
- async for output in async_stream:
- token: str = output["choices"][0]["delta"].get("content") or ""
- yield token
+ tools, tool_choice = _get_tools(messages, config)
+ # Asynchronously stream the LLM response, which is either a tool call or an assistant response.
+ chunks = []
+ clipped_messages = _clip(messages, max_tokens)
+ if tools and config.llm.startswith("llama-cpp-python"):
+ # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
+ clipped_messages[-1]["content"] += (
+            "\n\n<tools>\n"
+ f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
+ "IMPORTANT: The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.\n"
+            "</tools>"
+ )
+ async_stream = await acompletion(
+ model=config.llm,
+ messages=clipped_messages,
+ tools=tools,
+ tool_choice=tool_choice,
+ stream=True,
+ )
+ async for chunk in async_stream:
+ chunks.append(chunk)
+ if isinstance(token := chunk.choices[0].delta.content, str):
+ yield token
+ # Check if there are tools to be called.
+ response = stream_chunk_builder(chunks, messages)
+ tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr]
+ if tool_calls:
+ # Add the tool call requests to the message array.
+ messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
+ # Run the tool calls to retrieve the RAG context and append the output to the message array.
+ # TODO: Make this async.
+ messages.extend(_run_tools(tool_calls, on_retrieval, config))
+ # Asynchronously stream the assistant response.
+ chunks = []
+ async_stream = await acompletion(
+ model=config.llm, messages=_clip(messages, max_tokens), stream=True
+ )
+ async for chunk in async_stream:
+ chunks.append(chunk)
+ if isinstance(token := chunk.choices[0].delta.content, str):
+ yield token
+ # Append the assistant response to the message array.
+ response = stream_chunk_builder(chunks, messages)
+ messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr]
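
To make the message-history contract concrete: a `search_knowledge_base` call leads `_run_tools` to append a tool message whose content is a JSON object with a `documents` array, or an empty JSON object when the `expert` workaround skips the search. A sketch with hypothetical ids and content:

```python
# Tool output when retrieval ran (the content string is built from ChunkSpan.to_json):
retrieved_message = {
    "role": "tool",
    "content": (
        '{"documents": [{"index": 1, "id": "doc-1", "source": "Special Relativity.pdf", '
        '"span": {"from_chunk_id": "chunk-7", "to_chunk_id": "chunk-9", "headings": "...", "content": "..."}}]}'
    ),
    "tool_call_id": "call-0",  # hypothetical tool call id
}

# Tool output when the LLM sets `expert` to false and the search is skipped:
skipped_message = {"role": "tool", "content": "{}", "tool_call_id": "call-0"}
```

This is the structure that the tests in `tests/test_rag.py` below inspect with `json.loads(messages[-2]["content"])`.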
diff --git a/src/raglite/_search.py b/src/raglite/_search.py
index c109324..54fec07 100644
--- a/src/raglite/_search.py
+++ b/src/raglite/_search.py
@@ -1,5 +1,6 @@
"""Search and retrieve chunks."""
+import contextlib
import re
import string
from collections import defaultdict
@@ -8,7 +9,7 @@
from typing import cast
import numpy as np
-from langdetect import detect
+from langdetect import LangDetectException, detect
from sqlalchemy.engine import make_url
from sqlalchemy.orm import joinedload
from sqlmodel import Session, and_, col, or_, select, text
@@ -212,8 +213,9 @@ def rerank_chunks(
# Select the reranker.
if isinstance(config.reranker, Sequence):
# Detect the languages of the chunks and queries.
- langs = {detect(str(chunk)) for chunk in chunks}
- langs.add(detect(query))
+ with contextlib.suppress(LangDetectException):
+ langs = {detect(str(chunk)) for chunk in chunks}
+ langs.add(detect(query))
# If all chunks and the query are in the same language, use a language-specific reranker.
rerankers = dict(config.reranker)
if len(langs) == 1 and (lang := next(iter(langs))) in rerankers:
diff --git a/tests/conftest.py b/tests/conftest.py
index ef41889..943250c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@
from pathlib import Path
import pytest
+from llama_cpp import llama_supports_gpu_offload
from sqlalchemy import create_engine, text
from raglite import RAGLiteConfig, insert_document
@@ -23,6 +24,11 @@ def is_postgres_running() -> bool:
return False
+def is_accelerator_available() -> bool:
+ """Check if an accelerator is available."""
+ return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004
+
+
def is_openai_available() -> bool:
"""Check if an OpenAI API key is set."""
return bool(os.environ.get("OPENAI_API_KEY"))
@@ -69,24 +75,41 @@ def database(request: pytest.FixtureRequest) -> str:
scope="session",
params=[
pytest.param(
- "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance.
- id="bge_m3",
+ (
+ "llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096",
+ "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance.
+ ),
+ id="llama32_3B-bge_m3",
),
pytest.param(
- "text-embedding-3-small",
- id="openai_text_embedding_3_small",
+ ("gpt-4o-mini", "text-embedding-3-small"),
+ id="gpt_4o_mini-text_embedding_3_small",
marks=pytest.mark.skipif(not is_openai_available(), reason="OpenAI API key is not set"),
),
],
)
-def embedder(request: pytest.FixtureRequest) -> str:
- """Get an embedder model URL to test RAGLite with."""
- embedder: str = request.param
+def llm_embedder(request: pytest.FixtureRequest) -> tuple[str, str]:
+    """Get an LLM and embedder pair to test RAGLite with."""
+    llm_embedder: tuple[str, str] = request.param
+    return llm_embedder
+
+
+@pytest.fixture(scope="session")
+def llm(llm_embedder: tuple[str, str]) -> str:
+ """Get an LLM to test RAGLite with."""
+ llm, _ = llm_embedder
+ return llm
+
+
+@pytest.fixture(scope="session")
+def embedder(llm_embedder: tuple[str, str]) -> str:
+ """Get an embedder to test RAGLite with."""
+ _, embedder = llm_embedder
return embedder
@pytest.fixture(scope="session")
-def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig:
+def raglite_test_config(database: str, llm: str, embedder: str) -> RAGLiteConfig:
"""Create a lightweight in-memory config for testing SQLite and PostgreSQL."""
# Select the database based on the embedder.
variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
@@ -95,7 +118,7 @@ def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig:
elif "sqlite" in database:
database = database.replace(".sqlite", f"_{variant}.sqlite")
# Create a RAGLite config for the given database and embedder.
- db_config = RAGLiteConfig(db_url=database, embedder=embedder)
+ db_config = RAGLiteConfig(db_url=database, llm=llm, embedder=embedder)
# Insert a document and update the index.
doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper.
insert_document(doc_path, config=db_config)
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 33ef6e0..9cf8dd1 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -9,18 +9,6 @@
from raglite._extract import extract_with_llm
-@pytest.fixture(
- params=[
- pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"),
- pytest.param("gpt-4o-mini", id="openai"),
- ]
-)
-def llm(request: pytest.FixtureRequest) -> str:
- """Get an LLM to test RAGLite with."""
- llm: str = request.param
- return llm
-
-
@pytest.mark.parametrize(
"strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")]
)
@@ -41,7 +29,7 @@ class LoginResponse(BaseModel):
# Extract structured data.
username, password = "cypher", "steak"
login_response = extract_with_llm(
- LoginResponse, f"{username} // {password}", strict=strict, config=config
+ LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config
)
# Validate the response.
assert isinstance(login_response, LoginResponse)
diff --git a/tests/test_rag.py b/tests/test_rag.py
index 7643bcf..ff0e0ab 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -1,28 +1,20 @@
"""Test RAGLite's RAG functionality."""
-import os
-
-import pytest
-from llama_cpp import llama_supports_gpu_offload
+import json
from raglite import (
RAGLiteConfig,
create_rag_instruction,
retrieve_rag_context,
)
+from raglite._database import ChunkSpan
from raglite._rag import rag
-def is_accelerator_available() -> bool:
- """Check if an accelerator is available."""
- return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004
-
-
-@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
-def test_rag(raglite_test_config: RAGLiteConfig) -> None:
- """Test Retrieval-Augmented Generation."""
- # Answer a question with RAG.
- user_prompt = "What does it mean for two events to be simultaneous?"
+def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
+ """Test Retrieval-Augmented Generation with manual retrieval."""
+ # Answer a question with manual RAG.
+ user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
chunk_spans = retrieve_rag_context(query=user_prompt, config=raglite_test_config)
messages = [create_rag_instruction(user_prompt, context=chunk_spans)]
stream = rag(messages, config=raglite_test_config)
@@ -30,4 +22,48 @@ def test_rag(raglite_test_config: RAGLiteConfig) -> None:
for update in stream:
assert isinstance(update, str)
answer += update
- assert "simultaneous" in answer.lower()
+ assert "event" in answer.lower()
+ # Verify that no RAG context was retrieved through tool use.
+ assert [message["role"] for message in messages] == ["user", "assistant"]
+
+
+def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
+ """Test Retrieval-Augmented Generation with automatic retrieval."""
+ # Answer a question that requires RAG.
+ user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
+ messages = [{"role": "user", "content": user_prompt}]
+ chunk_spans = []
+ stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config)
+ answer = ""
+ for update in stream:
+ assert isinstance(update, str)
+ answer += update
+ assert "event" in answer.lower()
+ # Verify that RAG context was retrieved automatically.
+ assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"]
+ assert json.loads(messages[-2]["content"])
+ assert chunk_spans
+ assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
+
+
+def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
+    """Test Retrieval-Augmented Generation where automatic retrieval is not needed."""
+ # Answer a question that does not require RAG.
+ user_prompt = "Is 7 a prime number? Answer with Yes or No only."
+ messages = [{"role": "user", "content": user_prompt}]
+ chunk_spans = []
+ stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config)
+ answer = ""
+ for update in stream:
+ assert isinstance(update, str)
+ answer += update
+ assert "yes" in answer.lower()
+ # Verify that no RAG context was retrieved.
+ if raglite_test_config.llm.startswith("llama-cpp-python"):
+ # Llama.cpp does not support streaming tool_choice="auto" yet, so instead we verify that the
+ # LLM indicates that the tool call request may be skipped by checking that content is empty.
+ assert [msg["role"] for msg in messages] == ["user", "assistant", "tool", "assistant"]
+ assert not json.loads(messages[-2]["content"])
+ else:
+ assert [msg["role"] for msg in messages] == ["user", "assistant"]
+ assert not chunk_spans