diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index b371912..2577d5e 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -104,7 +104,7 @@ def _get_tools(
{
"skip": {
"type": "boolean",
- "description": "True if a satisfactory answer can be provided without the knowledge base, false otherwise.",
+ "description": "True if a satisfactory answer can be provided without searching the knowledge base, false otherwise.",
}
}
if llm_provider == "llama-cpp-python"
@@ -125,11 +125,10 @@ def _get_tools(
"type": ["string", "null"],
"description": "\n".join( # noqa: FLY002
[
- "The query string to search the knowledge base with.",
+ "The query string to search the knowledge base with, or `null` if `skip` is `true`.",
"The query string MUST satisfy ALL of the following criteria:"
"- The query string MUST be a precise question in the user's language.",
"- The query string MUST resolve all pronouns to explicit nouns from the conversation history.",
- "- The query string MUST be `null` if `skip` is `true`.",
]
),
},
@@ -195,7 +194,10 @@ def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[st
if tools and config.llm.startswith("llama-cpp-python"):
# Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
clipped_messages[-1]["content"] += (
- f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}"
+ "\n\n\n"
+ f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
+ "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory answer without searching the knowledge base.\n"
+ ""
)
stream = completion(
model=config.llm,
@@ -238,7 +240,10 @@ async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) ->
if tools and config.llm.startswith("llama-cpp-python"):
# Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
clipped_messages[-1]["content"] += (
- f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}"
+ "\n\n\n"
+ f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
+            "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory answer without searching the knowledge base.\n"
+ ""
)
async_stream = await acompletion(
model=config.llm,
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 2bdf07b..9cf8dd1 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -29,7 +29,7 @@ class LoginResponse(BaseModel):
# Extract structured data.
username, password = "cypher", "steak"
login_response = extract_with_llm(
- LoginResponse, f"{username} // {password}", strict=strict, config=config
+ LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config
)
# Validate the response.
assert isinstance(login_response, LoginResponse)
diff --git a/tests/test_rag.py b/tests/test_rag.py
index 7a391ba..edd60dc 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -44,7 +44,7 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
"""Test Retrieval-Augmented Generation with automatic retrieval."""
# Answer a question that does not require RAG.
- user_prompt = "Is 7 a prime number?"
+ user_prompt = "Yes or no: is 'Veni, vidi, vici' a Latin phrase?"
messages = [{"role": "user", "content": user_prompt}]
stream = rag(messages, config=raglite_test_config)
answer = ""