From 319f52288cc524353d933126b8d192c0dfc461b0 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Mon, 9 Dec 2024 17:04:19 +0100 Subject: [PATCH 01/11] feat: let LLM choose whether to retrieve context --- README.md | 39 +++++-- poetry.lock | 8 +- pyproject.toml | 2 +- src/raglite/_database.py | 31 +++++- src/raglite/_litellm.py | 54 ++++++---- src/raglite/_rag.py | 214 +++++++++++++++++++++++++++++++++++---- src/raglite/_search.py | 8 +- tests/conftest.py | 57 +++++++++-- tests/test_rag.py | 57 ++++++++--- 9 files changed, 395 insertions(+), 75 deletions(-) diff --git a/README.md b/README.md index bc0affd..3d2bada 100644 --- a/README.md +++ b/README.md @@ -159,9 +159,36 @@ insert_document(Path("Special Relativity.pdf"), config=my_config) ### 3. Searching and Retrieval-Augmented Generation (RAG) -#### 3.1 Simple RAG pipeline +#### 3.1 Minimal RAG pipeline -Now you can run a simple but powerful RAG pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response: +Now you can run a minimal RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking it as a tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans. The retrieval results are appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the response: + +```python +from raglite import rag + +# Create a user message: +messages = [] # Or start with an existing message history. +messages.append({ + "role": "user", + "content": "How is intelligence measured?" +}) + +# Let the LLM decide whether to search the database by providing a search method as a tool to the LLM. +# If requested, RAGLite then uses hybrid search and reranking to append RAG context to the message history. +# Finally, LLM response is streamed and appended to the message history. +stream = rag(messages, config=my_config) +for update in stream: + print(update, end="") + +# Access the RAG context appended to the message history: +import json + +context = [json.loads(message["content"]) for message in messages if message["role"] == "tool"] +``` + +#### 3.2 Basic RAG pipeline + +If you want control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response: ```python from raglite import create_rag_instruction, rag, retrieve_rag_context @@ -179,16 +206,16 @@ stream = rag(messages, config=my_config) for update in stream: print(update, end="") -# Access the documents cited in the RAG response: +# Access the documents referenced in the RAG context: documents = [chunk_span.document for chunk_span in chunk_spans] ``` -#### 3.2 Advanced RAG pipeline +#### 3.3 Advanced RAG pipeline > [!TIP] > 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks. -In addition to the simple RAG pipeline, RAGLite also offers more advanced control over the individual steps of the pipeline. A full pipeline consists of several steps: +In addition to the basic RAG pipeline, RAGLite also offers more advanced control over the pipeline. A full pipeline consists of several steps: 1. Searching for relevant chunks with keyword, vector, or hybrid search 2. Retrieving the chunks from the database @@ -236,7 +263,7 @@ stream = rag(messages, config=my_config) for update in stream: print(update, end="") -# Access the documents cited in the RAG response: +# Access the documents referenced in the RAG context: documents = [chunk_span.document for chunk_span in chunk_spans] ``` diff --git a/poetry.lock b/poetry.lock index 5626d81..7353b43 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2547,13 +2547,13 @@ test = ["coverage", "pytest", "pytest-cov"] [[package]] name = "litellm" -version = "1.47.1" +version = "1.48.10" description = "Library to easily interface with LLM API providers" optional = false python-versions = "!=2.7.*,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,!=3.7.*,>=3.8" files = [ - {file = "litellm-1.47.1-py3-none-any.whl", hash = "sha256:baa1961287ee398c937e8a5ecd1fcb821ea8f91cbd1f4757b6c19d7bcc84d4fd"}, - {file = "litellm-1.47.1.tar.gz", hash = "sha256:51d1eb353573ddeac75c45b66147f533f64f231540667ea30b63edb9a2af15ce"}, + {file = "litellm-1.48.10-py3-none-any.whl", hash = "sha256:752efd59747a0895f4695d025c66f0b2258d80a61175f7cfa41dbe4894ef95e1"}, + {file = "litellm-1.48.10.tar.gz", hash = "sha256:0a4ff75da78e66baeae0658ad8de498298310a5efda74c3d840ce2b013e8401d"}, ] [package.dependencies] @@ -6813,4 +6813,4 @@ ragas = ["ragas"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<4.0" -content-hash = "ff8a8596ac88ae5406234f08810e2e4d654714af0aa3663451e988a2cf6ef51e" +content-hash = "b3a14066711fe4caec356d0aa18514495d44ac253371d2560fc0c5aea890aaef" diff --git a/pyproject.toml b/pyproject.toml index 3cecd49..10e3f57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ scipy = ">=1.5.0" spacy = ">=3.7.0,<3.8.0" # Large Language Models: huggingface-hub = ">=0.22.0" -litellm = ">=1.47.1" +litellm = ">=1.48.4" llama-cpp-python = ">=0.3.2" pydantic = ">=2.7.0" # Approximate Nearest Neighbors: diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 95fc88b..573a3cc 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -169,18 +169,41 @@ def to_xml(self, index: int | None = None) -> str: if not self.chunks: return "" index_attribute = f' index="{index}"' if index is not None else "" - xml = "\n".join( + xml_document = "\n".join( [ f'', f"{self.document.url if self.document.url else self.document.filename}", - f'', - f"\n{escape(self.chunks[0].headings.strip())}\n", + f'', + f"\n{escape(self.chunks[0].headings.strip())}\n", f"\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n", "", "", ] ) - return xml + return xml_document + + def to_json(self, index: int | None = None) -> str: + """Convert this chunk span to a JSON representation. + + The JSON representation follows Anthropic's best practices [1]. + + [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips + """ + if not self.chunks: + return "{}" + index_attribute = {"index": index} if index is not None else {} + json_document = { + **index_attribute, + "id": self.document.id, + "source": self.document.url if self.document.url else self.document.filename, + "span": { + "from_chunk_id": self.chunks[0].id, + "to_chunk_id": self.chunks[-1].id, + "headings": self.chunks[0].headings.strip(), + "content": "".join(chunk.body for chunk in self.chunks).strip(), + }, + } + return json.dumps(json_document) @property def content(self) -> str: diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index 1135e97..31bf279 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -13,6 +13,8 @@ import httpx import litellm from litellm import ( # type: ignore[attr-defined] + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, CustomLLM, GenericStreamingChunk, ModelResponse, @@ -112,6 +114,8 @@ def llm(model: str, **kwargs: Any) -> Llama: n_ctx=n_ctx, n_gpu_layers=-1, verbose=False, + # Enable function calling. + chat_format="chatml-function-calling", # Workaround to enable long context embedding models [1]. # [1] https://github.com/abetlen/llama-cpp-python/issues/1762 n_batch=n_ctx if n_ctx > 0 else 1024, @@ -218,24 +222,40 @@ def streaming( # noqa: PLR0913 llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True), ) for chunk in stream: - choices = chunk.get("choices", []) - for choice in choices: - text = choice.get("delta", {}).get("content", None) - finish_reason = choice.get("finish_reason") - litellm_generic_streaming_chunk = GenericStreamingChunk( - text=text, # type: ignore[typeddict-item] - is_finished=bool(finish_reason), - finish_reason=finish_reason, # type: ignore[typeddict-item] - usage=None, - index=choice.get("index"), # type: ignore[typeddict-item] - provider_specific_fields={ - "id": chunk.get("id"), - "model": chunk.get("model"), - "created": chunk.get("created"), - "object": chunk.get("object"), - }, + choices = chunk.get("choices") + if not choices: + continue + text = choices[0].get("delta", {}).get("content", None) + tool_calls = choices[0].get("delta", {}).get("tool_calls", None) + tool_use = ( + ChatCompletionToolCallChunk( + id=tool_calls[0]["id"], # type: ignore[index] + type="function", + function=ChatCompletionToolCallFunctionChunk( + name=tool_calls[0]["function"]["name"], # type: ignore[index] + arguments=tool_calls[0]["function"]["arguments"], # type: ignore[index] + ), + index=tool_calls[0]["index"], # type: ignore[index] ) - yield litellm_generic_streaming_chunk + if tool_calls + else None + ) + finish_reason = choices[0].get("finish_reason") + litellm_generic_streaming_chunk = GenericStreamingChunk( + text=text, # type: ignore[typeddict-item] + tool_use=tool_use, + is_finished=bool(finish_reason), + finish_reason=finish_reason, # type: ignore[typeddict-item] + usage=None, + index=choices[0].get("index"), # type: ignore[typeddict-item] + provider_specific_fields={ + "id": chunk.get("id"), + "model": chunk.get("model"), + "created": chunk.get("created"), + "object": chunk.get("object"), + }, + ) + yield litellm_generic_streaming_chunk async def astreaming( # type: ignore[misc,override] # noqa: PLR0913 self, diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 8fb1a0c..b371912 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -1,9 +1,17 @@ """Retrieval-augmented generation.""" +import json from collections.abc import AsyncIterator, Iterator +from typing import Any import numpy as np -from litellm import acompletion, completion +from litellm import ( # type: ignore[attr-defined] + ChatCompletionMessageToolCall, + acompletion, + completion, + stream_chunk_builder, + supports_function_calling, +) from raglite._config import RAGLiteConfig from raglite._database import ChunkSpan @@ -70,25 +78,197 @@ def create_rag_instruction( return message +def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]: + """Left clip a messages array to avoid hitting the context limit.""" + cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1]) + first_message = -np.searchsorted(cum_tokens, max_tokens) + return messages[first_message:] + + +def _get_tools( + messages: list[dict[str, str]], config: RAGLiteConfig +) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | str | None]: + """Get tools to search the knowledge base if no RAG context is provided in the messages.""" + # Check if messages already contain RAG context or if the LLM supports tool use. + final_message = messages[-1].get("content", "") + messages_contain_rag_context = any(s in final_message for s in ("", "from_chunk_id")) + llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None + llm_supports_function_calling = supports_function_calling(config.llm, llm_provider) + if not messages_contain_rag_context and not llm_supports_function_calling: + error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling." + raise ValueError(error_message) + # Add a tool to search the knowledge base if no RAG context is provided in the messages. Because + # llama-cpp-python cannot stream tool_use='auto' yet, we use a workaround that forces the LLM + # to use a tool, but allows it to skip the search. + auto_tool_use_workaround = ( + { + "skip": { + "type": "boolean", + "description": "True if a satisfactory answer can be provided without the knowledge base, false otherwise.", + } + } + if llm_provider == "llama-cpp-python" + else {} + ) + tools: list[dict[str, Any]] | None = ( + [ + { + "type": "function", + "function": { + "name": "search_knowledge_base", + "description": "Search the knowledge base. Note: only use this tool if not enough information is available to provide an answer.", + "parameters": { + "type": "object", + "properties": { + **auto_tool_use_workaround, + "query": { + "type": ["string", "null"], + "description": "\n".join( # noqa: FLY002 + [ + "The query string to search the knowledge base with.", + "The query string MUST satisfy ALL of the following criteria:" + "- The query string MUST be a precise question in the user's language.", + "- The query string MUST resolve all pronouns to explicit nouns from the conversation history.", + "- The query string MUST be `null` if `skip` is `true`.", + ] + ), + }, + }, + "required": [*list(auto_tool_use_workaround), "query"], + "additionalProperties": False, + }, + }, + } + ] + if not messages_contain_rag_context + else None + ) + tool_choice: dict[str, Any] | str | None = ( + ( + {"type": "function", "function": {"name": "search_knowledge_base"}} + if auto_tool_use_workaround + else "auto" + ) + if tools + else None + ) + return tools, tool_choice + + +def _run_tools( + tool_calls: list[ChatCompletionMessageToolCall], config: RAGLiteConfig +) -> list[dict[str, Any]]: + """Run tools to search the knowledge base for RAG context.""" + tool_messages: list[dict[str, Any]] = [] + for tool_call in tool_calls: + if tool_call.function.name == "search_knowledge_base": + kwargs = json.loads(tool_call.function.arguments) + kwargs["config"] = config + skip = kwargs.pop("skip", False) + tool_messages.append( + { + "role": "tool", + "content": '{{"documents": [{elements}]}}'.format( + elements=", ".join( + chunk_span.to_json(index=i + 1) + for i, chunk_span in enumerate(retrieve_rag_context(**kwargs)) + ) + ) + if not skip and kwargs["query"] + else "{}", + "tool_call_id": tool_call.id, + } + ) + else: + error_message = f"Unknown function `{tool_call.function.name}`." + raise ValueError(error_message) + return tool_messages + + def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[str]: - # Truncate the oldest messages so we don't hit the context limit. + # If the final message does not contain RAG context, get a tool to search the knowledge base. max_tokens = get_context_size(config) - cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1]) - messages = messages[-np.searchsorted(cum_tokens, max_tokens) :] - # Stream the LLM response. - stream = completion(model=config.llm, messages=messages, stream=True) - for output in stream: - token: str = output["choices"][0]["delta"].get("content") or "" - yield token + tools, tool_choice = _get_tools(messages, config) + # Stream the LLM response, which is either a tool call request or an assistant response. + chunks = [] + clipped_messages = _clip(messages, max_tokens) + if tools and config.llm.startswith("llama-cpp-python"): + # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call. + clipped_messages[-1]["content"] += ( + f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}" + ) + stream = completion( + model=config.llm, + messages=clipped_messages, + tools=tools, + tool_choice=tool_choice, + stream=True, + ) + for chunk in stream: + chunks.append(chunk) + if isinstance(token := chunk.choices[0].delta.content, str): + yield token + # Check if there are tools to be called. + response = stream_chunk_builder(chunks, messages) + tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + if tool_calls: + # Add the tool call request to the message array. + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + # Run the tool calls to retrieve the RAG context and append the output to the message array. + messages.extend(_run_tools(tool_calls, config)) + # Stream the assistant response. + chunks = [] + stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True) + for chunk in stream: + chunks.append(chunk) + if isinstance(token := chunk.choices[0].delta.content, str): + yield token + # Append the assistant response to the message array. + response = stream_chunk_builder(chunks, messages) + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> AsyncIterator[str]: - # Truncate the oldest messages so we don't hit the context limit. + # If the final message does not contain RAG context, get a tool to search the knowledge base. max_tokens = get_context_size(config) - cum_tokens = np.cumsum([len(message.get("content", "")) // 3 for message in messages][::-1]) - messages = messages[-np.searchsorted(cum_tokens, max_tokens) :] - # Asynchronously stream the LLM response. - async_stream = await acompletion(model=config.llm, messages=messages, stream=True) - async for output in async_stream: - token: str = output["choices"][0]["delta"].get("content") or "" - yield token + tools, tool_choice = _get_tools(messages, config) + # Asynchronously stream the LLM response, which is either a tool call or an assistant response. + chunks = [] + clipped_messages = _clip(messages, max_tokens) + if tools and config.llm.startswith("llama-cpp-python"): + # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call. + clipped_messages[-1]["content"] += ( + f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}" + ) + async_stream = await acompletion( + model=config.llm, + messages=clipped_messages, + tools=tools, + tool_choice=tool_choice, + stream=True, + ) + async for chunk in async_stream: + chunks.append(chunk) + if isinstance(token := chunk.choices[0].delta.content, str): + yield token + # Check if there are tools to be called. + response = stream_chunk_builder(chunks, messages) + tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + if tool_calls: + # Add the tool call requests to the message array. + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + # Run the tool calls to retrieve the RAG context and append the output to the message array. + # TODO: Make this async. + messages.extend(_run_tools(tool_calls, config)) + # Asynchronously stream the assistant response. + chunks = [] + async_stream = await acompletion( + model=config.llm, messages=_clip(messages, max_tokens), stream=True + ) + async for chunk in async_stream: + chunks.append(chunk) + if isinstance(token := chunk.choices[0].delta.content, str): + yield token + # Append the assistant response to the message array. + response = stream_chunk_builder(chunks, messages) + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] diff --git a/src/raglite/_search.py b/src/raglite/_search.py index c109324..54fec07 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -1,5 +1,6 @@ """Search and retrieve chunks.""" +import contextlib import re import string from collections import defaultdict @@ -8,7 +9,7 @@ from typing import cast import numpy as np -from langdetect import detect +from langdetect import LangDetectException, detect from sqlalchemy.engine import make_url from sqlalchemy.orm import joinedload from sqlmodel import Session, and_, col, or_, select, text @@ -212,8 +213,9 @@ def rerank_chunks( # Select the reranker. if isinstance(config.reranker, Sequence): # Detect the languages of the chunks and queries. - langs = {detect(str(chunk)) for chunk in chunks} - langs.add(detect(query)) + with contextlib.suppress(LangDetectException): + langs = {detect(str(chunk)) for chunk in chunks} + langs.add(detect(query)) # If all chunks and the query are in the same language, use a language-specific reranker. rerankers = dict(config.reranker) if len(langs) == 1 and (lang := next(iter(langs))) in rerankers: diff --git a/tests/conftest.py b/tests/conftest.py index ef41889..1465138 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from pathlib import Path import pytest +from llama_cpp import llama_supports_gpu_offload from sqlalchemy import create_engine, text from raglite import RAGLiteConfig, insert_document @@ -23,6 +24,11 @@ def is_postgres_running() -> bool: return False +def is_accelerator_available() -> bool: + """Check if an accelerator is available.""" + return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004 + + def is_openai_available() -> bool: """Check if an OpenAI API key is set.""" return bool(os.environ.get("OPENAI_API_KEY")) @@ -69,24 +75,57 @@ def database(request: pytest.FixtureRequest) -> str: scope="session", params=[ pytest.param( - "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance. - id="bge_m3", + ( + "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4096", + "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance. + ), + id="llama31_8B-bge_m3", + marks=pytest.mark.skipif( + not is_accelerator_available(), reason="No accelerator available" + ), ), pytest.param( - "text-embedding-3-small", - id="openai_text_embedding_3_small", + ( + "gpt-4o-mini", + "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance. + ), + id="gpt_4o_mini-bge_m3", + marks=pytest.mark.skipif( + not is_openai_available() or is_accelerator_available(), + reason="OpenAI API key is not set" + if not is_openai_available() + else "Local LLM available", + ), + ), + pytest.param( + ("gpt-4o-mini", "text-embedding-3-small"), + id="gpt_4o_mini-text_embedding_3_small", marks=pytest.mark.skipif(not is_openai_available(), reason="OpenAI API key is not set"), ), ], ) -def embedder(request: pytest.FixtureRequest) -> str: - """Get an embedder model URL to test RAGLite with.""" - embedder: str = request.param +def llm_embedder(request: pytest.FixtureRequest) -> str: + """Get an LLM and embedder pair to test RAGLite with.""" + llm_embedder: str = request.param + return llm_embedder + + +@pytest.fixture(scope="session") +def llm(llm_embedder: tuple[str, str]) -> str: + """Get an LLM to test RAGLite with.""" + llm, _ = llm_embedder + return llm + + +@pytest.fixture(scope="session") +def embedder(llm_embedder: tuple[str, str]) -> str: + """Get an embedder to test RAGLite with.""" + _, embedder = llm_embedder return embedder @pytest.fixture(scope="session") -def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig: +def raglite_test_config(database: str, llm: str, embedder: str) -> RAGLiteConfig: """Create a lightweight in-memory config for testing SQLite and PostgreSQL.""" # Select the database based on the embedder. variant = "local" if embedder.startswith("llama-cpp-python") else "remote" @@ -95,7 +134,7 @@ def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig: elif "sqlite" in database: database = database.replace(".sqlite", f"_{variant}.sqlite") # Create a RAGLite config for the given database and embedder. - db_config = RAGLiteConfig(db_url=database, embedder=embedder) + db_config = RAGLiteConfig(db_url=database, llm=llm, embedder=embedder) # Insert a document and update the index. doc_path = Path(__file__).parent / "specrel.pdf" # Einstein's special relativity paper. insert_document(doc_path, config=db_config) diff --git a/tests/test_rag.py b/tests/test_rag.py index 7643bcf..7a391ba 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -1,9 +1,6 @@ """Test RAGLite's RAG functionality.""" -import os - -import pytest -from llama_cpp import llama_supports_gpu_offload +import json from raglite import ( RAGLiteConfig, @@ -13,21 +10,53 @@ from raglite._rag import rag -def is_accelerator_available() -> bool: - """Check if an accelerator is available.""" - return llama_supports_gpu_offload() or (os.cpu_count() or 1) >= 8 # noqa: PLR2004 - - -@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available") -def test_rag(raglite_test_config: RAGLiteConfig) -> None: - """Test Retrieval-Augmented Generation.""" - # Answer a question with RAG. - user_prompt = "What does it mean for two events to be simultaneous?" +def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: + """Test Retrieval-Augmented Generation with manual retrieval.""" + # Answer a question with manual RAG. + user_prompt = "What is special relativity's definition of 'simultaneous events'?" chunk_spans = retrieve_rag_context(query=user_prompt, config=raglite_test_config) messages = [create_rag_instruction(user_prompt, context=chunk_spans)] stream = rag(messages, config=raglite_test_config) answer = "" + for update in stream: + assert isinstance(update, str) + answer += update + assert "event" in answer.lower() + # Verify that no RAG context was retrieved through tool use. + assert [message["role"] for message in messages] == ["user", "assistant"] + + +def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: + """Test Retrieval-Augmented Generation with automatic retrieval.""" + # Answer a question that requires RAG. + user_prompt = "What is special relativity's definition of 'simultaneous events'?" + messages = [{"role": "user", "content": user_prompt}] + stream = rag(messages, config=raglite_test_config) + answer = "" for update in stream: assert isinstance(update, str) answer += update assert "simultaneous" in answer.lower() + # Verify that RAG context was retrieved automatically. + assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"] + + +def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None: + """Test Retrieval-Augmented Generation with automatic retrieval.""" + # Answer a question that does not require RAG. + user_prompt = "Is 7 a prime number?" + messages = [{"role": "user", "content": user_prompt}] + stream = rag(messages, config=raglite_test_config) + answer = "" + for update in stream: + assert isinstance(update, str) + answer += update + assert "yes" in answer.lower() + # Verify that no RAG context was retrieved. + if raglite_test_config.llm.startswith("llama-cpp-python"): + # Llama.cpp does not support streaming tool_choice="auto" yet, so instead we verify that the + # LLM indicates that the tool call request may be skipped by checking that content is empty. + assert [msg["role"] for msg in messages] == ["user", "assistant", "tool", "assistant"] + assert not json.loads(messages[-2]["content"]) + else: + assert [msg["role"] for msg in messages] == ["user", "assistant"] From aa5cc8035cee1b6a648a5385b1689522d016a1d8 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Mon, 9 Dec 2024 18:09:51 +0100 Subject: [PATCH 02/11] test: improve config consistency --- tests/test_extract.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/test_extract.py b/tests/test_extract.py index 33ef6e0..2bdf07b 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -9,18 +9,6 @@ from raglite._extract import extract_with_llm -@pytest.fixture( - params=[ - pytest.param(RAGLiteConfig().llm, id="llama_cpp_python"), - pytest.param("gpt-4o-mini", id="openai"), - ] -) -def llm(request: pytest.FixtureRequest) -> str: - """Get an LLM to test RAGLite with.""" - llm: str = request.param - return llm - - @pytest.mark.parametrize( "strict", [pytest.param(False, id="strict=False"), pytest.param(True, id="strict=True")] ) From 93a75b993cf4672c0fc60a7f7ddf3b27b231af65 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Mon, 9 Dec 2024 18:10:17 +0100 Subject: [PATCH 03/11] feat: add tool use to Chainlit --- src/raglite/_chainlit.py | 54 +++++++++++++++------------------------- 1 file changed, 20 insertions(+), 34 deletions(-) diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index 1f3eeeb..f3e15d9 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -1,27 +1,17 @@ """Chainlit frontend for RAGLite.""" +import json import os from pathlib import Path import chainlit as cl from chainlit.input_widget import Switch, TextInput -from raglite import ( - RAGLiteConfig, - async_rag, - create_rag_instruction, - hybrid_search, - insert_document, - rerank_chunks, - retrieve_chunk_spans, - retrieve_chunks, -) +from raglite import RAGLiteConfig, async_rag, hybrid_search, insert_document, rerank_chunks from raglite._markdown import document_to_markdown async_insert_document = cl.make_async(insert_document) async_hybrid_search = cl.make_async(hybrid_search) -async_retrieve_chunks = cl.make_async(retrieve_chunks) -async_retrieve_chunk_spans = cl.make_async(retrieve_chunk_spans) async_rerank_chunks = cl.make_async(rerank_chunks) @@ -93,31 +83,27 @@ async def handle_message(user_message: cl.Message) -> None: for i, attachment in enumerate(inline_attachments) ) + f"\n\n{user_message.content}" - ) - # Search for relevant contexts for RAG. - async with cl.Step(name="search", type="retrieval") as step: - step.input = user_message.content - chunk_ids, _ = await async_hybrid_search(query=user_prompt, num_results=10, config=config) - chunks = await async_retrieve_chunks(chunk_ids=chunk_ids, config=config) - step.output = chunks - step.elements = [ # Show the top chunks inline. - cl.Text(content=str(chunk), display="inline") for chunk in chunks[:5] - ] - await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602. - # Rerank the chunks and group them into chunk spans. - async with cl.Step(name="rerank", type="rerank") as step: - step.input = chunks - chunks = await async_rerank_chunks(query=user_prompt, chunk_ids=chunks, config=config) - chunk_spans = await async_retrieve_chunk_spans(chunks[:5], config=config) - step.output = chunk_spans - step.elements = [ # Show the top chunk spans inline. - cl.Text(content=str(chunk_span), display="inline") for chunk_span in chunk_spans - ] - await step.update() # TODO: Workaround for https://github.com/Chainlit/chainlit/issues/602. + ).strip() # Stream the LLM response. assistant_message = cl.Message(content="") messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call] - messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans)) + messages.append({"role": "user", "content": user_prompt}) async for token in async_rag(messages, config=config): await assistant_message.stream_token(token) + # Append RAG sources if any. + if messages[-2]["role"] == "tool": + rag_context = json.loads(messages[-2]["content"]) + rag_sources: dict[str, list[str]] = {} + for document in rag_context["documents"]: + rag_sources.setdefault(document["source"], []) + rag_sources[document["source"]].append( + document["span"]["headings"] + "\n" + document["span"]["content"] + ) + assistant_message.content += "\n\nSources: " + ", ".join( # Rendered as hyperlinks. + f"[{i + 1}]" for i in range(len(rag_sources)) + ) + assistant_message.elements = [ # Markdown content is rendered in sidebar. + cl.Text(name=f"[{i + 1}]", content="\n\n---\n\n".join(content), display="side") # type: ignore[misc] + for i, (_, content) in enumerate(rag_sources.items()) + ] await assistant_message.update() # type: ignore[no-untyped-call] From 4728d2808a4b3d6c4f692780a036612eb25bb0e7 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Mon, 9 Dec 2024 18:28:59 +0100 Subject: [PATCH 04/11] test: increase LLM coverage --- tests/conftest.py | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1465138..ee92f87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -76,25 +76,19 @@ def database(request: pytest.FixtureRequest) -> str: params=[ pytest.param( ( - "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4096", + "llama-cpp-python/bartowski/Llama-3.2-3B-Instruct-GGUF/*Q4_K_M.gguf@4096", "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance. ), - id="llama31_8B-bge_m3", - marks=pytest.mark.skipif( - not is_accelerator_available(), reason="No accelerator available" - ), + id="llama32_3B-bge_m3", ), pytest.param( ( - "gpt-4o-mini", + "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4096", "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance. ), - id="gpt_4o_mini-bge_m3", + id="llama31_8B-bge_m3", marks=pytest.mark.skipif( - not is_openai_available() or is_accelerator_available(), - reason="OpenAI API key is not set" - if not is_openai_available() - else "Local LLM available", + not is_accelerator_available(), reason="No accelerator available" ), ), pytest.param( From 888cb439344cc8ea099cbc1ace7891bf764b98bc Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Mon, 9 Dec 2024 19:16:32 +0100 Subject: [PATCH 05/11] fix: make tool use more robust --- src/raglite/_rag.py | 15 ++++++++++----- tests/test_extract.py | 2 +- tests/test_rag.py | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index b371912..2577d5e 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -104,7 +104,7 @@ def _get_tools( { "skip": { "type": "boolean", - "description": "True if a satisfactory answer can be provided without the knowledge base, false otherwise.", + "description": "True if a satisfactory answer can be provided without searching the knowledge base, false otherwise.", } } if llm_provider == "llama-cpp-python" @@ -125,11 +125,10 @@ def _get_tools( "type": ["string", "null"], "description": "\n".join( # noqa: FLY002 [ - "The query string to search the knowledge base with.", + "The query string to search the knowledge base with, or `null` if `skip` is `true`.", "The query string MUST satisfy ALL of the following criteria:" "- The query string MUST be a precise question in the user's language.", "- The query string MUST resolve all pronouns to explicit nouns from the conversation history.", - "- The query string MUST be `null` if `skip` is `true`.", ] ), }, @@ -195,7 +194,10 @@ def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[st if tools and config.llm.startswith("llama-cpp-python"): # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call. clipped_messages[-1]["content"] += ( - f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}" + "\n\n\n" + f"Available tools:\n```\n{json.dumps(tools)}\n```\n" + "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory answer without searching the knowledge base.\n" + "" ) stream = completion( model=config.llm, @@ -238,7 +240,10 @@ async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> if tools and config.llm.startswith("llama-cpp-python"): # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call. clipped_messages[-1]["content"] += ( - f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}" + "\n\n\n" + f"Available tools:\n```\n{json.dumps(tools)}\n```\n" + "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory response without searching the knowledge base.\n" + "" ) async_stream = await acompletion( model=config.llm, diff --git a/tests/test_extract.py b/tests/test_extract.py index 2bdf07b..9cf8dd1 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -29,7 +29,7 @@ class LoginResponse(BaseModel): # Extract structured data. username, password = "cypher", "steak" login_response = extract_with_llm( - LoginResponse, f"{username} // {password}", strict=strict, config=config + LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config ) # Validate the response. assert isinstance(login_response, LoginResponse) diff --git a/tests/test_rag.py b/tests/test_rag.py index 7a391ba..edd60dc 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -44,7 +44,7 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None: """Test Retrieval-Augmented Generation with automatic retrieval.""" # Answer a question that does not require RAG. - user_prompt = "Is 7 a prime number?" + user_prompt = "Yes or no: is 'Veni, vidi, vici' a Latin phrase?" messages = [{"role": "user", "content": user_prompt}] stream = rag(messages, config=raglite_test_config) answer = "" From f440ae49776d22c2d638d6959755b602cce7572b Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Tue, 10 Dec 2024 15:52:00 +0100 Subject: [PATCH 06/11] fix: make RAG with SLMs more robust --- src/raglite/_chainlit.py | 3 +-- src/raglite/_rag.py | 26 ++++++++++++-------------- tests/conftest.py | 10 ---------- tests/test_rag.py | 9 +++++---- 4 files changed, 18 insertions(+), 30 deletions(-) diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index f3e15d9..3860599 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -91,8 +91,7 @@ async def handle_message(user_message: cl.Message) -> None: async for token in async_rag(messages, config=config): await assistant_message.stream_token(token) # Append RAG sources if any. - if messages[-2]["role"] == "tool": - rag_context = json.loads(messages[-2]["content"]) + if messages[-2]["role"] == "tool" and (rag_context := json.loads(messages[-2]["content"])): rag_sources: dict[str, list[str]] = {} for document in rag_context["documents"]: rag_sources.setdefault(document["source"], []) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 2577d5e..b87c39a 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -102,9 +102,9 @@ def _get_tools( # to use a tool, but allows it to skip the search. auto_tool_use_workaround = ( { - "skip": { + "expert": { "type": "boolean", - "description": "True if a satisfactory answer can be provided without searching the knowledge base, false otherwise.", + "description": "The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.", } } if llm_provider == "llama-cpp-python" @@ -116,20 +116,18 @@ def _get_tools( "type": "function", "function": { "name": "search_knowledge_base", - "description": "Search the knowledge base. Note: only use this tool if not enough information is available to provide an answer.", + "description": "Search the knowledge base. IMPORTANT: Only use this tool if a well-rounded non-expert would need to look up information to answer the question.", "parameters": { "type": "object", "properties": { **auto_tool_use_workaround, "query": { - "type": ["string", "null"], - "description": "\n".join( # noqa: FLY002 - [ - "The query string to search the knowledge base with, or `null` if `skip` is `true`.", - "The query string MUST satisfy ALL of the following criteria:" - "- The query string MUST be a precise question in the user's language.", - "- The query string MUST resolve all pronouns to explicit nouns from the conversation history.", - ] + "type": "string", + "description": ( + "The `query` string to search the knowledge base with.\n" + "The `query` string MUST satisfy ALL of the following criteria:\n" + "- The `query` string MUST be a precise question in the user's language.\n" + "- The `query` string MUST resolve all pronouns to explicit nouns from the conversation history." ), }, }, @@ -163,7 +161,7 @@ def _run_tools( if tool_call.function.name == "search_knowledge_base": kwargs = json.loads(tool_call.function.arguments) kwargs["config"] = config - skip = kwargs.pop("skip", False) + skip = not kwargs.pop("expert", True) tool_messages.append( { "role": "tool", @@ -196,7 +194,7 @@ def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[st clipped_messages[-1]["content"] += ( "\n\n\n" f"Available tools:\n```\n{json.dumps(tools)}\n```\n" - "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory answer without searching the knowledge base.\n" + "IMPORTANT: The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.\n" "" ) stream = completion( @@ -242,7 +240,7 @@ async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> clipped_messages[-1]["content"] += ( "\n\n\n" f"Available tools:\n```\n{json.dumps(tools)}\n```\n" - "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory response without searching the knowledge base.\n" + "IMPORTANT: The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.\n" "" ) async_stream = await acompletion( diff --git a/tests/conftest.py b/tests/conftest.py index ee92f87..943250c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,16 +81,6 @@ def database(request: pytest.FixtureRequest) -> str: ), id="llama32_3B-bge_m3", ), - pytest.param( - ( - "llama-cpp-python/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/*Q4_K_M.gguf@4096", - "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf@1024", # More context degrades performance. - ), - id="llama31_8B-bge_m3", - marks=pytest.mark.skipif( - not is_accelerator_available(), reason="No accelerator available" - ), - ), pytest.param( ("gpt-4o-mini", "text-embedding-3-small"), id="gpt_4o_mini-text_embedding_3_small", diff --git a/tests/test_rag.py b/tests/test_rag.py index edd60dc..39ffe12 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -13,7 +13,7 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: """Test Retrieval-Augmented Generation with manual retrieval.""" # Answer a question with manual RAG. - user_prompt = "What is special relativity's definition of 'simultaneous events'?" + user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" chunk_spans = retrieve_rag_context(query=user_prompt, config=raglite_test_config) messages = [create_rag_instruction(user_prompt, context=chunk_spans)] stream = rag(messages, config=raglite_test_config) @@ -29,22 +29,23 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: """Test Retrieval-Augmented Generation with automatic retrieval.""" # Answer a question that requires RAG. - user_prompt = "What is special relativity's definition of 'simultaneous events'?" + user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" messages = [{"role": "user", "content": user_prompt}] stream = rag(messages, config=raglite_test_config) answer = "" for update in stream: assert isinstance(update, str) answer += update - assert "simultaneous" in answer.lower() + assert "event" in answer.lower() # Verify that RAG context was retrieved automatically. assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"] + assert json.loads(messages[-2]["content"]) def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None: """Test Retrieval-Augmented Generation with automatic retrieval.""" # Answer a question that does not require RAG. - user_prompt = "Yes or no: is 'Veni, vidi, vici' a Latin phrase?" + user_prompt = "Is 7 a prime number? Answer with Yes or No only." messages = [{"role": "user", "content": user_prompt}] stream = rag(messages, config=raglite_test_config) answer = "" From 4f5f038f02c7c00819ac8bf5bc8d81a89eb65a65 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Tue, 10 Dec 2024 16:34:47 +0100 Subject: [PATCH 07/11] fix: improve LiteLLM usage --- src/raglite/_litellm.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index 31bf279..9f1d0ea 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -22,6 +22,7 @@ get_model_info, ) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.utils import custom_llm_setup from llama_cpp import ( # type: ignore[attr-defined] ChatCompletionRequestMessage, CreateChatCompletionResponse, @@ -33,6 +34,7 @@ from raglite._config import RAGLiteConfig # Reduce the logging level for LiteLLM, flashrank, and httpx. +litellm.suppress_debug_info = True os.environ["LITELLM_LOG"] = "WARNING" logging.getLogger("LiteLLM").setLevel(logging.WARNING) logging.getLogger("flashrank").setLevel(logging.WARNING) @@ -125,24 +127,23 @@ def llm(model: str, **kwargs: Any) -> Llama: # Enable caching. llm.set_cache(LlamaRAMCache()) # Register the model info with LiteLLM. - litellm.register_model( # type: ignore[attr-defined] - { - model: { - "max_tokens": llm.n_ctx(), - "max_input_tokens": llm.n_ctx(), - "max_output_tokens": None, - "input_cost_per_token": 0.0, - "output_cost_per_token": 0.0, - "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None, - "litellm_provider": "llama-cpp-python", - "mode": "embedding" if kwargs.get("embedding") else "completion", - "supported_openai_params": LlamaCppPythonLLM.supported_openai_params, - "supports_function_calling": True, - "supports_parallel_function_calling": True, - "supports_vision": False, - } + model_info = { + model: { + "max_tokens": llm.n_ctx(), + "max_input_tokens": llm.n_ctx(), + "max_output_tokens": None, + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + "output_vector_size": llm.n_embd() if kwargs.get("embedding") else None, + "litellm_provider": "llama-cpp-python", + "mode": "embedding" if kwargs.get("embedding") else "completion", + "supported_openai_params": LlamaCppPythonLLM.supported_openai_params, + "supports_function_calling": True, + "supports_parallel_function_calling": True, + "supports_vision": False, } - ) + } + litellm.register_model(model_info) # type: ignore[attr-defined] return llm def _translate_openai_params(self, optional_params: dict[str, Any]) -> dict[str, Any]: @@ -307,7 +308,7 @@ async def astreaming( # type: ignore[misc,override] # noqa: PLR0913 litellm.custom_provider_map.append( {"provider": "llama-cpp-python", "custom_handler": LlamaCppPythonLLM()} ) - litellm.suppress_debug_info = True + custom_llm_setup() # type: ignore[no-untyped-call] @cache From 9ae3d8d6ab0c82f0f057ac9a1cdcd5e03f5caef3 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Tue, 10 Dec 2024 19:04:55 +0100 Subject: [PATCH 08/11] fix: fix registering of llama.cpp model_info --- src/raglite/_extract.py | 5 ++--- src/raglite/_litellm.py | 8 +++----- src/raglite/_rag.py | 5 ++--- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/raglite/_extract.py b/src/raglite/_extract.py index bd85d47..c902e68 100644 --- a/src/raglite/_extract.py +++ b/src/raglite/_extract.py @@ -34,13 +34,12 @@ class MyNameResponse(BaseModel): # Load the default config if not provided. config = config or RAGLiteConfig() # Check if the LLM supports the response format. - llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None llm_supports_response_format = "response_format" in ( - get_supported_openai_params(model=config.llm, custom_llm_provider=llm_provider) or [] + get_supported_openai_params(model=config.llm) or [] ) # Update the system prompt with the JSON schema of the return type to help the LLM. system_prompt = getattr(return_type, "system_prompt", "").strip() - if not llm_supports_response_format or llm_provider == "llama-cpp-python": + if not llm_supports_response_format or config.llm.startswith("llama-cpp-python"): system_prompt += f"\n\nFormat your response according to this JSON schema:\n{return_type.model_json_schema()!s}" # Constrain the reponse format to the JSON schema if it's supported by the LLM [1]. Strict mode # is disabled by default because it only supports a subset of JSON schema features [2]. diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index 9f1d0ea..e0bab3b 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -128,7 +128,7 @@ def llm(model: str, **kwargs: Any) -> Llama: llm.set_cache(LlamaRAMCache()) # Register the model info with LiteLLM. model_info = { - model: { + repo_id_filename: { "max_tokens": llm.n_ctx(), "max_input_tokens": llm.n_ctx(), "max_output_tokens": None, @@ -319,8 +319,7 @@ def get_context_size(config: RAGLiteConfig, *, fallback: int = 2048) -> int: if config.llm.startswith("llama-cpp-python"): _ = LlamaCppPythonLLM.llm(config.llm) # Attempt to read the context size from LiteLLM's model info. - llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None - model_info = get_model_info(config.llm, custom_llm_provider=llm_provider) + model_info = get_model_info(config.llm) max_tokens = model_info.get("max_tokens") if isinstance(max_tokens, int) and max_tokens > 0: return max_tokens @@ -343,8 +342,7 @@ def get_embedding_dim(config: RAGLiteConfig, *, fallback: bool = True) -> int: if config.embedder.startswith("llama-cpp-python"): _ = LlamaCppPythonLLM.llm(config.embedder, embedding=True) # Attempt to read the embedding dimension from LiteLLM's model info. - llm_provider = "llama-cpp-python" if config.embedder.startswith("llama-cpp") else None - model_info = get_model_info(config.embedder, custom_llm_provider=llm_provider) + model_info = get_model_info(config.embedder) embedding_dim = model_info.get("output_vector_size") if isinstance(embedding_dim, int) and embedding_dim > 0: return embedding_dim diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index b87c39a..9abd59f 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -92,8 +92,7 @@ def _get_tools( # Check if messages already contain RAG context or if the LLM supports tool use. final_message = messages[-1].get("content", "") messages_contain_rag_context = any(s in final_message for s in ("", "from_chunk_id")) - llm_provider = "llama-cpp-python" if config.llm.startswith("llama-cpp") else None - llm_supports_function_calling = supports_function_calling(config.llm, llm_provider) + llm_supports_function_calling = supports_function_calling(config.llm) if not messages_contain_rag_context and not llm_supports_function_calling: error_message = "You must either explicitly provide RAG context in the last message, or use an LLM that supports function calling." raise ValueError(error_message) @@ -107,7 +106,7 @@ def _get_tools( "description": "The `expert` boolean MUST be true if the question requires domain-specific or expert-level knowledge to answer, and false otherwise.", } } - if llm_provider == "llama-cpp-python" + if config.llm.startswith("llama-cpp-python") else {} ) tools: list[dict[str, Any]] | None = ( From a011621151e0bff8787cf11f509c0c1296e7c43c Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Tue, 10 Dec 2024 19:46:54 +0100 Subject: [PATCH 09/11] feat: add an on_retrieval callback --- README.md | 13 ++++++------- src/raglite/_rag.py | 29 ++++++++++++++++++++++------- tests/test_rag.py | 10 ++++++++-- 3 files changed, 36 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 3d2bada..02b2658 100644 --- a/README.md +++ b/README.md @@ -161,7 +161,7 @@ insert_document(Path("Special Relativity.pdf"), config=my_config) #### 3.1 Minimal RAG pipeline -Now you can run a minimal RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking it as a tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans. The retrieval results are appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the response: +Now you can run a minimal RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking it as a tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans (each of which is a list of consecutive chunks). The retrieval results are received by the `on_retrieval` callback, and are also appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the response: ```python from raglite import rag @@ -176,19 +176,18 @@ messages.append({ # Let the LLM decide whether to search the database by providing a search method as a tool to the LLM. # If requested, RAGLite then uses hybrid search and reranking to append RAG context to the message history. # Finally, LLM response is streamed and appended to the message history. -stream = rag(messages, config=my_config) +chunk_spans = [] +stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=my_config) for update in stream: print(update, end="") -# Access the RAG context appended to the message history: -import json - -context = [json.loads(message["content"]) for message in messages if message["role"] == "tool"] +# Access the documents referenced in the RAG context: +documents = [chunk_span.document for chunk_span in chunk_spans] ``` #### 3.2 Basic RAG pipeline -If you want control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response: +If you want control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response: ```python from raglite import create_rag_instruction, rag, retrieve_rag_context diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 9abd59f..da52e23 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -1,7 +1,7 @@ """Retrieval-augmented generation.""" import json -from collections.abc import AsyncIterator, Iterator +from collections.abc import AsyncIterator, Callable, Iterator from typing import Any import numpy as np @@ -152,7 +152,9 @@ def _get_tools( def _run_tools( - tool_calls: list[ChatCompletionMessageToolCall], config: RAGLiteConfig + tool_calls: list[ChatCompletionMessageToolCall], + on_retrieval: Callable[[list[ChunkSpan]], None] | None, + config: RAGLiteConfig, ) -> list[dict[str, Any]]: """Run tools to search the knowledge base for RAG context.""" tool_messages: list[dict[str, Any]] = [] @@ -161,13 +163,14 @@ def _run_tools( kwargs = json.loads(tool_call.function.arguments) kwargs["config"] = config skip = not kwargs.pop("expert", True) + chunk_spans = retrieve_rag_context(**kwargs) if not skip and kwargs["query"] else None tool_messages.append( { "role": "tool", "content": '{{"documents": [{elements}]}}'.format( elements=", ".join( chunk_span.to_json(index=i + 1) - for i, chunk_span in enumerate(retrieve_rag_context(**kwargs)) + for i, chunk_span in enumerate(chunk_spans) # type: ignore[arg-type] ) ) if not skip and kwargs["query"] @@ -175,13 +178,20 @@ def _run_tools( "tool_call_id": tool_call.id, } ) + if chunk_spans and callable(on_retrieval): + on_retrieval(chunk_spans) else: error_message = f"Unknown function `{tool_call.function.name}`." raise ValueError(error_message) return tool_messages -def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[str]: +def rag( + messages: list[dict[str, str]], + *, + on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, + config: RAGLiteConfig, +) -> Iterator[str]: # If the final message does not contain RAG context, get a tool to search the knowledge base. max_tokens = get_context_size(config) tools, tool_choice = _get_tools(messages, config) @@ -214,7 +224,7 @@ def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[st # Add the tool call request to the message array. messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] # Run the tool calls to retrieve the RAG context and append the output to the message array. - messages.extend(_run_tools(tool_calls, config)) + messages.extend(_run_tools(tool_calls, on_retrieval, config)) # Stream the assistant response. chunks = [] stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True) @@ -227,7 +237,12 @@ def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[st messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] -async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> AsyncIterator[str]: +async def async_rag( + messages: list[dict[str, str]], + *, + on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, + config: RAGLiteConfig, +) -> AsyncIterator[str]: # If the final message does not contain RAG context, get a tool to search the knowledge base. max_tokens = get_context_size(config) tools, tool_choice = _get_tools(messages, config) @@ -261,7 +276,7 @@ async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] # Run the tool calls to retrieve the RAG context and append the output to the message array. # TODO: Make this async. - messages.extend(_run_tools(tool_calls, config)) + messages.extend(_run_tools(tool_calls, on_retrieval, config)) # Asynchronously stream the assistant response. chunks = [] async_stream = await acompletion( diff --git a/tests/test_rag.py b/tests/test_rag.py index 39ffe12..ff0e0ab 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -7,6 +7,7 @@ create_rag_instruction, retrieve_rag_context, ) +from raglite._database import ChunkSpan from raglite._rag import rag @@ -31,7 +32,8 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: # Answer a question that requires RAG. user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" messages = [{"role": "user", "content": user_prompt}] - stream = rag(messages, config=raglite_test_config) + chunk_spans = [] + stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config) answer = "" for update in stream: assert isinstance(update, str) @@ -40,6 +42,8 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: # Verify that RAG context was retrieved automatically. assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"] assert json.loads(messages[-2]["content"]) + assert chunk_spans + assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None: @@ -47,7 +51,8 @@ def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None: # Answer a question that does not require RAG. user_prompt = "Is 7 a prime number? Answer with Yes or No only." messages = [{"role": "user", "content": user_prompt}] - stream = rag(messages, config=raglite_test_config) + chunk_spans = [] + stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=raglite_test_config) answer = "" for update in stream: assert isinstance(update, str) @@ -61,3 +66,4 @@ def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None: assert not json.loads(messages[-2]["content"]) else: assert [msg["role"] for msg in messages] == ["user", "assistant"] + assert not chunk_spans From a446373e81f5e5230416ff0e8fdce2d02eefbf00 Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Tue, 10 Dec 2024 20:11:12 +0100 Subject: [PATCH 10/11] docs: improve README section on dynamic routing --- README.md | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 02b2658..fbe2a1b 100644 --- a/README.md +++ b/README.md @@ -159,9 +159,9 @@ insert_document(Path("Special Relativity.pdf"), config=my_config) ### 3. Searching and Retrieval-Augmented Generation (RAG) -#### 3.1 Minimal RAG pipeline +#### 3.1 Dynamically routed RAG -Now you can run a minimal RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking it as a tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans (each of which is a list of consecutive chunks). The retrieval results are received by the `on_retrieval` callback, and are also appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the response: +Now you can run a dynamically routed RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking a retrieval tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans (each of which is a list of consecutive chunks). The retrieval results are sent to the `on_retrieval` callback and are also appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the assistant response: ```python from raglite import rag @@ -173,9 +173,9 @@ messages.append({ "content": "How is intelligence measured?" }) -# Let the LLM decide whether to search the database by providing a search method as a tool to the LLM. +# Let the LLM decide whether to search the database by providing a retrieval tool to the LLM. # If requested, RAGLite then uses hybrid search and reranking to append RAG context to the message history. -# Finally, LLM response is streamed and appended to the message history. +# Finally, assistant response is streamed and appended to the message history. chunk_spans = [] stream = rag(messages, on_retrieval=lambda x: chunk_spans.extend(x), config=my_config) for update in stream: @@ -185,9 +185,9 @@ for update in stream: documents = [chunk_span.document for chunk_span in chunk_spans] ``` -#### 3.2 Basic RAG pipeline +#### 3.2 Programmable RAG -If you want control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response: +If you need manual control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response: ```python from raglite import create_rag_instruction, rag, retrieve_rag_context @@ -200,7 +200,7 @@ chunk_spans = retrieve_rag_context(query=user_prompt, num_chunks=5, config=my_co messages = [] # Or start with an existing message history. messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans)) -# Stream the RAG response: +# Stream the RAG response and append it to the message history: stream = rag(messages, config=my_config) for update in stream: print(update, end="") @@ -209,12 +209,10 @@ for update in stream: documents = [chunk_span.document for chunk_span in chunk_spans] ``` -#### 3.3 Advanced RAG pipeline - > [!TIP] > 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks. -In addition to the basic RAG pipeline, RAGLite also offers more advanced control over the pipeline. A full pipeline consists of several steps: +RAGLite also offers more advanced control over the individual steps of a full RAG pipeline: 1. Searching for relevant chunks with keyword, vector, or hybrid search 2. Retrieving the chunks from the database @@ -255,7 +253,7 @@ from raglite import create_rag_instruction messages = [] # Or start with an existing message history. messages.append(create_rag_instruction(user_prompt=user_prompt, context=chunk_spans)) -# Stream the RAG response: +# Stream the RAG response and append it to the message history: from raglite import rag stream = rag(messages, config=my_config) From 82def46981f9a729ca2d377ed3a98120910930ac Mon Sep 17 00:00:00 2001 From: Laurent Sorber Date: Tue, 10 Dec 2024 20:29:06 +0100 Subject: [PATCH 11/11] fix: update Chainlit integration to use the new callback --- src/raglite/_chainlit.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/raglite/_chainlit.py b/src/raglite/_chainlit.py index 3860599..7dd53ef 100644 --- a/src/raglite/_chainlit.py +++ b/src/raglite/_chainlit.py @@ -1,6 +1,5 @@ """Chainlit frontend for RAGLite.""" -import json import os from pathlib import Path @@ -86,18 +85,19 @@ async def handle_message(user_message: cl.Message) -> None: ).strip() # Stream the LLM response. assistant_message = cl.Message(content="") + chunk_spans = [] messages: list[dict[str, str]] = cl.chat_context.to_openai()[:-1] # type: ignore[no-untyped-call] messages.append({"role": "user", "content": user_prompt}) - async for token in async_rag(messages, config=config): + async for token in async_rag( + messages, on_retrieval=lambda x: chunk_spans.extend(x), config=config + ): await assistant_message.stream_token(token) - # Append RAG sources if any. - if messages[-2]["role"] == "tool" and (rag_context := json.loads(messages[-2]["content"])): + # Append RAG sources, if any. + if chunk_spans: rag_sources: dict[str, list[str]] = {} - for document in rag_context["documents"]: - rag_sources.setdefault(document["source"], []) - rag_sources[document["source"]].append( - document["span"]["headings"] + "\n" + document["span"]["content"] - ) + for chunk_span in chunk_spans: + rag_sources.setdefault(chunk_span.document.id, []) + rag_sources[chunk_span.document.id].append(str(chunk_span)) assistant_message.content += "\n\nSources: " + ", ".join( # Rendered as hyperlinks. f"[{i + 1}]" for i in range(len(rag_sources)) )