From 888cb439344cc8ea099cbc1ace7891bf764b98bc Mon Sep 17 00:00:00 2001
From: Laurent Sorber
Date: Mon, 9 Dec 2024 19:16:32 +0100
Subject: [PATCH] fix: make tool use more robust

---
 src/raglite/_rag.py   | 15 ++++++++++-----
 tests/test_extract.py |  2 +-
 tests/test_rag.py     |  2 +-
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index b371912..2577d5e 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -104,7 +104,7 @@ def _get_tools(
                     {
                         "skip": {
                             "type": "boolean",
-                            "description": "True if a satisfactory answer can be provided without the knowledge base, false otherwise.",
+                            "description": "True if a satisfactory answer can be provided without searching the knowledge base, false otherwise.",
                         }
                     }
                     if llm_provider == "llama-cpp-python"
@@ -125,11 +125,10 @@ def _get_tools(
                             "type": ["string", "null"],
                             "description": "\n".join(  # noqa: FLY002
                                 [
-                                    "The query string to search the knowledge base with.",
+                                    "The query string to search the knowledge base with, or `null` if `skip` is `true`.",
                                     "The query string MUST satisfy ALL of the following criteria:"
                                     "- The query string MUST be a precise question in the user's language.",
                                     "- The query string MUST resolve all pronouns to explicit nouns from the conversation history.",
-                                    "- The query string MUST be `null` if `skip` is `true`.",
                                 ]
                             ),
                         },
@@ -195,7 +194,10 @@ def rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) -> Iterator[st
     if tools and config.llm.startswith("llama-cpp-python"):
         # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
         clipped_messages[-1]["content"] += (
-            f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}"
+            "\n\n<tools>\n"
+            f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
+            "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory answer without searching the knowledge base.\n"
+            "</tools>"
         )
         stream = completion(
             model=config.llm,
@@ -238,7 +240,10 @@ async def async_rag(messages: list[dict[str, str]], *, config: RAGLiteConfig) ->
     if tools and config.llm.startswith("llama-cpp-python"):
         # Help llama.cpp LLMs plan their response by providing a JSON schema for the tool call.
         clipped_messages[-1]["content"] += (
-            f"\n\nDecide whether to use or skip these tools in your response:\n{json.dumps(tools)}"
+            "\n\n<tools>\n"
+            f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
+            "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory response without searching the knowledge base.\n"
+            "</tools>"
         )
         async_stream = await acompletion(
             model=config.llm,
diff --git a/tests/test_extract.py b/tests/test_extract.py
index 2bdf07b..9cf8dd1 100644
--- a/tests/test_extract.py
+++ b/tests/test_extract.py
@@ -29,7 +29,7 @@ class LoginResponse(BaseModel):
     # Extract structured data.
     username, password = "cypher", "steak"
     login_response = extract_with_llm(
-        LoginResponse, f"{username} // {password}", strict=strict, config=config
+        LoginResponse, f"username: {username}\npassword: {password}", strict=strict, config=config
     )
     # Validate the response.
     assert isinstance(login_response, LoginResponse)
diff --git a/tests/test_rag.py b/tests/test_rag.py
index 7a391ba..edd60dc 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -44,7 +44,7 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
 def test_rag_auto_without_retrieval(raglite_test_config: RAGLiteConfig) -> None:
     """Test Retrieval-Augmented Generation with automatic retrieval."""
     # Answer a question that does not require RAG.
-    user_prompt = "Is 7 a prime number?"
+    user_prompt = "Yes or no: is 'Veni, vidi, vici' a Latin phrase?"
     messages = [{"role": "user", "content": user_prompt}]
     stream = rag(messages, config=raglite_test_config)
     answer = ""
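
Note (not part of the patch): below is a minimal, self-contained sketch of the prompt suffix that the llama-cpp-python branch builds after this change. The tool schema here is hypothetical, shaped only after the `skip` and `query` parameters visible in the hunks above; in RAGLite the real schema comes from `_get_tools` in src/raglite/_rag.py, and the tool name used below is illustrative.

import json

# Hypothetical tool schema: a boolean `skip` and a nullable string `query`,
# mirroring the two parameters edited in the @@ -104 and @@ -125 hunks.
tools = [
    {
        "type": "function",
        "function": {
            "name": "search_knowledge_base",  # illustrative name
            "parameters": {
                "type": "object",
                "properties": {
                    "skip": {"type": "boolean"},
                    "query": {"type": ["string", "null"]},
                },
            },
        },
    }
]

# The suffix appended to the last user message, as in the @@ -195 hunk:
# the tool JSON schema wrapped in <tools> tags plus an explicit skip rule.
suffix = (
    "\n\n<tools>\n"
    f"Available tools:\n```\n{json.dumps(tools)}\n```\n"
    "IMPORTANT: You MUST set skip=true and query=null if you can provide a satisfactory answer without searching the knowledge base.\n"
    "</tools>"
)
print(suffix)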