diff --git a/backend/main.py b/backend/main.py index 7f1e2dbea7e..4688a437b6d 100644 --- a/backend/main.py +++ b/backend/main.py @@ -336,6 +336,19 @@ def get_tools_function_calling_payload(messages, task_model_id, content): "metadata": {"task": str(TASKS.FUNCTION_CALLING)}, } +def fixup_response_content(content): + try: + # Attempt to parse the entire content as JSON + json.loads(content) + return content + except json.JSONDecodeError: + # If parsing fails, try to parse only the first line + first_line = content.split('\n')[0].strip() + try: + json.loads(first_line) + return first_line + except json.JSONDecodeError: + return content async def get_content_from_response(response) -> Optional[str]: content = None @@ -349,6 +362,7 @@ async def get_content_from_response(response) -> Optional[str]: await response.background() else: content = response["choices"][0]["message"]["content"] + content = fixup_response_content(content) return content @@ -357,9 +371,7 @@ async def chat_completion_tools_handler( ) -> tuple[dict, dict]: # If tool_ids field is present, call the functions metadata = body.get("metadata", {}) - - tool_ids = metadata.get("tool_ids", None) - log.debug(f"{tool_ids=}") + tool_ids = metadata.get("tool_ids", {}) if not tool_ids: return body, {} @@ -379,7 +391,7 @@ async def chat_completion_tools_handler( "__files__": metadata.get("files", []), }, ) - log.info(f"{tools=}") + log.debug(f"{tools=}") specs = [tool["spec"] for tool in tools.values()] tools_specs = json.dumps(specs) @@ -387,7 +399,7 @@ async def chat_completion_tools_handler( tools_function_calling_prompt = tools_function_calling_generation_template( app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs ) - log.info(f"{tools_function_calling_prompt=}") + log.debug(f"{tools_function_calling_prompt=}") payload = get_tools_function_calling_payload( body["messages"], task_model_id, tools_function_calling_prompt ) @@ -399,21 +411,18 @@ async def chat_completion_tools_handler( try: 
response = await generate_chat_completions(form_data=payload, user=user) - log.debug(f"{response=}") content = await get_content_from_response(response) - log.debug(f"{content=}") if not content: return body, {} - result = json.loads(content) + result = json.loads(content) if isinstance(content, str) else content tool_function_name = result.get("name", None) if tool_function_name not in tools: return body, {} tool_function_params = result.get("parameters", {}) - try: tool_output = await tools[tool_function_name]["callable"]( **tool_function_params @@ -435,14 +444,19 @@ async def chat_completion_tools_handler( skip_files = True if isinstance(tool_output, str): - contexts.append(tool_output) + # the function name and output will be exposed to the model + contexts.append( + { + "tool": tool_function_name, + "params": tool_function_params, + "output": tool_output + } + ) except Exception as e: log.exception(f"Error: {e}") content = None - log.debug(f"tool_contexts: {contexts}") - if skip_files and "files" in body.get("metadata", {}): del body["metadata"]["files"] @@ -487,6 +501,8 @@ async def get_body_and_model_and_user(request): raise Exception("Model not found") model = app.state.MODELS[model_id] + # tool ids aren't present for some models + body["tool_ids"] = model.get('info', {}).get('meta', {}).get('toolIds', []) user = get_current_user( request, get_http_authorization_cred(request.headers.get("Authorization")), @@ -568,13 +584,13 @@ async def dispatch(self, request: Request, call_next): # If context is not empty, insert it into the messages if len(contexts) > 0: - context_string = "/n".join(contexts).strip() prompt = get_last_user_message(body["messages"]) if prompt is None: raise Exception("No user message found") # Workaround for Ollama 2.0+ system prompt issue # TODO: replace with add_or_update_system_message if model["owned_by"] == "ollama": + context_string = "\n".join(map(lambda x: x["tool"] + ": " + x["output"], contexts)) body["messages"] = prepend_to_first_user_message_content( rag_template( 
rag_app.state.config.RAG_TEMPLATE, context_string, prompt @@ -582,11 +598,13 @@ async def dispatch(self, request: Request, call_next): body["messages"], ) else: - body["messages"] = add_or_update_system_message( - rag_template( - rag_app.state.config.RAG_TEMPLATE, context_string, prompt - ), - body["messages"], + # add a message for the model to see the tool calls and results + context_string = "\n".join(map(lambda x: x["tool"] + ": " + x["output"], contexts)) + body["messages"].append( + { + "role": "assistant", + "content": f"Tool calls and results: {context_string}", + } ) # If there are citations, add them to the data_items @@ -1337,8 +1355,6 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u @app.post("/api/task/title/completions") async def generate_title(form_data: dict, user=Depends(get_verified_user)): - print("generate_title") - model_id = form_data["model"] if model_id not in app.state.MODELS: raise HTTPException(