Skip to content

Commit

Permalink
fix: make tool use more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
lsorber committed Dec 9, 2024
1 parent 4728d28 commit 888cb43
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
15 changes: 10 additions & 5 deletions src/raglite/_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _get_tools(
{
"skip": {
"type": "boolean",
"description": "True if a satisfactory answer can be provided without the knowledge base, false otherwise.",
"description": "True if a satisfactory answer can be provided without searching the knowledge base, false otherwise.",
}
}
if llm_provider == "llama-cpp-python"
Expand All @@ -125,11 +125,10 @@ def _get_tools(
"type": ["string", "null"],
"description": "\n".join( # noqa: FLY002
[
"The query string to search the knowledge base with.",
"The query string to search the knowledge base with, or `null` if `skip` is `true`.",
"The query string MUST satisfy ALL of the following criteria:",
"- The query string MUST be a precise question in the user's language.",
"- The query string MUST resolve all pronouns to explicit nouns from the conversation history.",
"- The query string MUST be `null` if `skip` is `true`.",
]
),
},
Expand Down Expand Up @@ -195,7 +194,10 @@ def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[st
if tools and config.llm.startswith("llama-cpp-python"):
# Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
clipped_messages[-1]["content"] += (
f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}"
"\n\n<tools>\n"
f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
"IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory answer without searching the knowledge base.\n"
"</tools>"
)
stream = completion(
model=config.llm,
Expand Down Expand Up @@ -238,7 +240,10 @@ async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) ->
if tools and config.llm.startswith("llama-cpp-python"):
# Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
clipped_messages[-1]["content"] += (
f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}"
"\n\n<tools>\n"
f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
"IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory answer without searching the knowledge base.\n"
"</tools>"
)
async_stream = await acompletion(
model=config.llm,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class LoginResponse(BaseModel):
# Extract structured data.
username, password = "cypher", "steak"
login_response = extract_with_llm(
LoginResponse, f"{username} // {password}", strict=strict, config=config
LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config
)
# Validate the response.
assert isinstance(login_response, LoginResponse)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
"""Test Retrieval-Augmented Generation without automatic retrieval."""
# Answer a question that does not require RAG.
user_prompt = "Is 7 a prime number?"
user_prompt = "Yes or no: is 'Veni, vidi, vici' a Latin phrase?"
messages = [{"role": "user", "content": user_prompt}]
stream = rag(messages, config=raglite_test_config)
answer = ""
Expand Down

0 comments on commit 888cb43

Please sign in to comment.