feat: let LLM choose whether to retrieve context #62

Merged 11 commits on Dec 15, 2024
39 changes: 33 additions & 6 deletions README.md
@@ -159,9 +159,36 @@ insert_document(Path("Special Relativity.pdf"), config=my_config)

### 3. Searching and Retrieval-Augmented Generation (RAG)

#### 3.1 Simple RAG pipeline
#### 3.1 Minimal RAG pipeline

Now you can run a simple but powerful RAG pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:
Now you can run a minimal RAG pipeline that consists of adding the user prompt to the message history and streaming the LLM response. Depending on the user prompt, the LLM may choose to retrieve context using RAGLite by invoking it as a tool. If retrieval is necessary, the LLM determines the search query and RAGLite applies hybrid search with reranking to retrieve the most relevant chunk spans. The retrieval results are appended to the message history as a tool output. Finally, the LLM response given the RAG context is streamed and the message history is updated with the response:

```python
from raglite import rag

# Create a user message:
messages = [] # Or start with an existing message history.
messages.append({
"role": "user",
"content": "How is intelligence measured?"
})

# Let the LLM decide whether to search the database by providing a search method as a tool.
# If requested, RAGLite then uses hybrid search and reranking to append RAG context to the message history.
# Finally, the LLM response is streamed and appended to the message history.
stream = rag(messages, config=my_config)
for update in stream:
print(update, end="")

# Access the RAG context appended to the message history:
import json

context = [json.loads(message["content"]) for message in messages if message["role"] == "tool"]
```

#### 3.2 Basic RAG pipeline

If you want control over the RAG pipeline, you can run a basic but powerful pipeline that consists of retrieving the most relevant chunk spans (each of which is a list of consecutive chunks) with hybrid search and reranking, converting the user prompt to a RAG instruction and appending it to the message history, and finally generating the RAG response:

```python
from raglite import create_rag_instruction, rag, retrieve_rag_context
@@ -179,16 +206,16 @@ stream = rag(messages, config=my_config)
for update in stream:
print(update, end="")

# Access the documents cited in the RAG response:
# Access the documents referenced in the RAG context:
documents = [chunk_span.document for chunk_span in chunk_spans]
```

#### 3.2 Advanced RAG pipeline
#### 3.3 Advanced RAG pipeline

> [!TIP]
> 🥇 Reranking can significantly improve the output quality of a RAG application. To add reranking to your application: first search for a larger set of 20 relevant chunks, then rerank them with a [rerankers](https://github.com/AnswerDotAI/rerankers) reranker, and finally keep the top 5 chunks.
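
As a minimal sketch of that recipe (assuming RAGLite's `hybrid_search`, `retrieve_chunks`, and `rerank_chunks` accept the arguments shown here; see the full pipeline below for the exact steps):

```python
from raglite import hybrid_search, rerank_chunks, retrieve_chunks

# Search for a larger candidate set of 20 chunks with hybrid search:
query = "How is intelligence measured?"
chunk_ids, scores = hybrid_search(query, num_results=20, config=my_config)

# Retrieve the candidate chunks from the database:
chunks = retrieve_chunks(chunk_ids, config=my_config)

# Rerank the candidates and keep only the top 5 chunks:
chunks = rerank_chunks(query, chunks, config=my_config)[:5]
```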

In addition to the simple RAG pipeline, RAGLite also offers more advanced control over the individual steps of the pipeline. A full pipeline consists of several steps:
In addition to the basic RAG pipeline, RAGLite also offers more advanced control over the pipeline. A full pipeline consists of several steps:

1. Searching for relevant chunks with keyword, vector, or hybrid search
2. Retrieving the chunks from the database
@@ -236,7 +263,7 @@ stream = rag(messages, config=my_config)
for update in stream:
print(update, end="")

# Access the documents cited in the RAG response:
# Access the documents referenced in the RAG context:
documents = [chunk_span.document for chunk_span in chunk_spans]
```

8 changes: 4 additions & 4 deletions poetry.lock


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -32,7 +32,7 @@ scipy = ">=1.5.0"
spacy = ">=3.7.0,<3.8.0"
# Large Language Models:
huggingface-hub = ">=0.22.0"
litellm = ">=1.47.1"
litellm = ">=1.48.4"
llama-cpp-python = ">=0.3.2"
pydantic = ">=2.7.0"
# Approximate Nearest Neighbors:
31 changes: 27 additions & 4 deletions src/raglite/_database.py
@@ -169,18 +169,41 @@ def to_xml(self, index: int | None = None) -> str:
if not self.chunks:
return ""
index_attribute = f' index="{index}"' if index is not None else ""
xml = "\n".join(
xml_document = "\n".join(
[
f'<document{index_attribute} id="{self.document.id}">',
f"<source>{self.document.url if self.document.url else self.document.filename}</source>",
f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[0].id}">',
f"<heading>\n{escape(self.chunks[0].headings.strip())}\n</heading>",
f'<span from_chunk_id="{self.chunks[0].id}" to_chunk_id="{self.chunks[-1].id}">',
f"<headings>\n{escape(self.chunks[0].headings.strip())}\n</headings>",
f"<content>\n{escape(''.join(chunk.body for chunk in self.chunks).strip())}\n</content>",
"</span>",
"</document>",
]
)
return xml
return xml_document

def to_json(self, index: int | None = None) -> str:
"""Convert this chunk span to a JSON representation.

The JSON representation follows Anthropic's best practices [1].

[1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
"""
if not self.chunks:
return "{}"
index_attribute = {"index": index} if index is not None else {}
json_document = {
**index_attribute,
"id": self.document.id,
"source": self.document.url if self.document.url else self.document.filename,
"span": {
"from_chunk_id": self.chunks[0].id,
"to_chunk_id": self.chunks[-1].id,
"headings": self.chunks[0].headings.strip(),
"content": "".join(chunk.body for chunk in self.chunks).strip(),
},
}
return json.dumps(json_document)

@property
def content(self) -> str:
54 changes: 37 additions & 17 deletions src/raglite/_litellm.py
@@ -13,6 +13,8 @@
import httpx
import litellm
from litellm import ( # type: ignore[attr-defined]
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
CustomLLM,
GenericStreamingChunk,
ModelResponse,
@@ -112,6 +114,8 @@ def llm(model: str, **kwargs: Any) -> Llama:
n_ctx=n_ctx,
n_gpu_layers=-1,
verbose=False,
# Enable function calling.
chat_format="chatml-function-calling",
# Workaround to enable long context embedding models [1].
# [1] https://github.com/abetlen/llama-cpp-python/issues/1762
n_batch=n_ctx if n_ctx > 0 else 1024,
@@ -218,24 +222,40 @@ def streaming( # noqa: PLR0913
llm.create_chat_completion(messages=messages, **llama_cpp_python_params, stream=True),
)
for chunk in stream:
choices = chunk.get("choices", [])
for choice in choices:
text = choice.get("delta", {}).get("content", None)
finish_reason = choice.get("finish_reason")
litellm_generic_streaming_chunk = GenericStreamingChunk(
text=text, # type: ignore[typeddict-item]
is_finished=bool(finish_reason),
finish_reason=finish_reason, # type: ignore[typeddict-item]
usage=None,
index=choice.get("index"), # type: ignore[typeddict-item]
provider_specific_fields={
"id": chunk.get("id"),
"model": chunk.get("model"),
"created": chunk.get("created"),
"object": chunk.get("object"),
},
choices = chunk.get("choices")
if not choices:
continue
text = choices[0].get("delta", {}).get("content", None)
tool_calls = choices[0].get("delta", {}).get("tool_calls", None)
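# If this delta carries a tool call, convert llama-cpp-python's OpenAI-style tool call dict
# into LiteLLM's ChatCompletionToolCallChunk so downstream consumers receive a provider-agnostic chunk.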
tool_use = (
ChatCompletionToolCallChunk(
id=tool_calls[0]["id"], # type: ignore[index]
type="function",
function=ChatCompletionToolCallFunctionChunk(
name=tool_calls[0]["function"]["name"], # type: ignore[index]
arguments=tool_calls[0]["function"]["arguments"], # type: ignore[index]
),
index=tool_calls[0]["index"], # type: ignore[index]
)
yield litellm_generic_streaming_chunk
if tool_calls
else None
)
finish_reason = choices[0].get("finish_reason")
litellm_generic_streaming_chunk = GenericStreamingChunk(
text=text, # type: ignore[typeddict-item]
tool_use=tool_use,
is_finished=bool(finish_reason),
finish_reason=finish_reason, # type: ignore[typeddict-item]
usage=None,
index=choices[0].get("index"), # type: ignore[typeddict-item]
provider_specific_fields={
"id": chunk.get("id"),
"model": chunk.get("model"),
"created": chunk.get("created"),
"object": chunk.get("object"),
},
)
yield litellm_generic_streaming_chunk

async def astreaming( # type: ignore[misc,override] # noqa: PLR0913
self,