56 changes: 36 additions & 20 deletions backend/main.py
@@ -336,6 +336,19 @@ def get_tools_function_calling_payload(messages, task_model_id, content):
"metadata": {"task": str(TASKS.FUNCTION_CALLING)},
}

def fixup_response_content(content):
try:
# Attempt to parse the entire content as JSON
json.loads(content)
return content
except json.JSONDecodeError:
# If parsing fails, try to parse only the first line
first_line = content.split('\n')[0].strip()
try:
json.loads(first_line)
return first_line
except json.JSONDecodeError:
return content

async def get_content_from_response(response) -> Optional[str]:
content = None
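Illustrative sketch (not part of the diff): the new helper is meant to salvage a tool-call reply when the model appends free-form text after the JSON object on the first line. A minimal, self-contained demo of that behavior, with made-up sample strings:

import json

def fixup_response_content(content):
    # Copied from the hunk above so the demo is self-contained.
    try:
        json.loads(content)
        return content
    except json.JSONDecodeError:
        first_line = content.split('\n')[0].strip()
        try:
            json.loads(first_line)
            return first_line
        except json.JSONDecodeError:
            return content

# A clean JSON tool call passes through unchanged.
print(fixup_response_content('{"name": "get_weather", "parameters": {"city": "Oslo"}}'))

# JSON followed by extra chatter on later lines: only the first line is kept.
print(fixup_response_content('{"name": "get_weather", "parameters": {"city": "Oslo"}}\nCalling the tool now.'))

# Plain text that never parses as JSON is returned as-is.
print(fixup_response_content('I cannot call a tool for this request.'))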
@@ -349,6 +362,7 @@ async def get_content_from_response(response) -> Optional[str]:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
content = fixup_response_content(content)
return content


@@ -357,9 +371,7 @@ async def chat_completion_tools_handler(
) -> tuple[dict, dict]:
# If tool_ids field is present, call the functions
metadata = body.get("metadata", {})

tool_ids = metadata.get("tool_ids", None)
log.debug(f"{tool_ids=}")
tool_ids = metadata.get("tool_ids", {})
if not tool_ids:
return body, {}

@@ -379,15 +391,15 @@
"__files__": metadata.get("files", []),
},
)
log.info(f"{tools=}")
log.debug(f"{tools=}")

specs = [tool["spec"] for tool in tools.values()]
tools_specs = json.dumps(specs)

tools_function_calling_prompt = tools_function_calling_generation_template(
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, tools_specs
)
log.info(f"{tools_function_calling_prompt=}")
log.debug(f"{tools_function_calling_prompt=}")
payload = get_tools_function_calling_payload(
body["messages"], task_model_id, tools_function_calling_prompt
)
@@ -399,21 +411,18 @@

try:
response = await generate_chat_completions(form_data=payload, user=user)
log.debug(f"{response=}")
content = await get_content_from_response(response)
log.debug(f"{content=}")

if not content:
return body, {}

result = json.loads(content)
result = json.loads(content) if type(content) == str else content

tool_function_name = result.get("name", None)
if tool_function_name not in tools:
return body, {}

tool_function_params = result.get("parameters", {})

try:
tool_output = await tools[tool_function_name]["callable"](
**tool_function_params
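Illustrative sketch (not part of the diff): the content parsed here is expected to be a JSON object naming a tool and its arguments, and the added conditional only guards against content that is already a dict. A rough sketch of that flow with a hypothetical tools registry (tool name and handler are made up), using isinstance for the same string check:

import json

# Hypothetical registry mirroring the shape of the `tools` dict built earlier in the handler.
tools = {
    "get_weather": {
        "spec": {"name": "get_weather", "parameters": {"city": "string"}},
        "callable": lambda city: f"Sunny in {city}",
    }
}

content = '{"name": "get_weather", "parameters": {"city": "Oslo"}}'
result = json.loads(content) if isinstance(content, str) else content

tool_function_name = result.get("name", None)
tool_function_params = result.get("parameters", {})

if tool_function_name in tools:
    # The real handler awaits an async callable; a plain call keeps the sketch short.
    print(tools[tool_function_name]["callable"](**tool_function_params))  # Sunny in Oslo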
@@ -435,14 +444,19 @@
skip_files = True

if isinstance(tool_output, str):
contexts.append(tool_output)
# the function name and output will be exposed to the model
contexts.append(
{
"tool": tool_function_name,
"params": tool_function_params,
"output": tool_output
}
)

except Exception as e:
log.exception(f"Error: {e}")
content = None

log.debug(f"tool_contexts: {contexts}")

if skip_files and "files" in body.get("metadata", {}):
del body["metadata"]["files"]

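Illustrative sketch (not part of the diff): string tool results are now stored as structured entries instead of bare strings. What one such entry might look like, with made-up values:

contexts = []

# Made-up values standing in for a finished tool call.
tool_function_name = "get_weather"
tool_function_params = {"city": "Oslo"}
tool_output = "Sunny in Oslo"

if isinstance(tool_output, str):
    # Same structured entry as the hunk above: name, params, and raw output kept together.
    contexts.append(
        {
            "tool": tool_function_name,
            "params": tool_function_params,
            "output": tool_output,
        }
    )

print(contexts)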
@@ -487,6 +501,8 @@ async def get_body_and_model_and_user(request):
raise Exception("Model not found")
model = app.state.MODELS[model_id]

# tool ids aren't present for some models
body["tool_ids"] = model.get('info', {}).get('meta', {}).get('toolIds', [])
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
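Illustrative sketch (not part of the diff): the added lookup falls back to an empty list when a model carries no tool metadata. A small demo of the chained .get() fallback with made-up model records:

# Made-up model records; only the nested info/meta structure matters here.
model_with_tools = {"info": {"meta": {"toolIds": ["web_search", "calculator"]}}}
model_without_meta = {"info": {}}

# Each .get() falls back to an empty dict or list whenever a level is missing.
print(model_with_tools.get("info", {}).get("meta", {}).get("toolIds", []))    # ['web_search', 'calculator']
print(model_without_meta.get("info", {}).get("meta", {}).get("toolIds", []))  # []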
@@ -568,25 +584,27 @@ async def dispatch(self, request: Request, call_next):

# If context is not empty, insert it into the messages
if len(contexts) > 0:
context_string = "/n".join(contexts).strip()
prompt = get_last_user_message(body["messages"])
if prompt is None:
raise Exception("No user message found")
# Workaround for Ollama 2.0+ system prompt issue
# TODO: replace with add_or_update_system_message
if model["owned_by"] == "ollama":
context_string = "/n".join(contexts).strip()
body["messages"] = prepend_to_first_user_message_content(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
else:
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
# add a message for the model to see the tool calls and results
context_string = "\n".join(map(lambda x: x["tool"] + ": " + x["output"], contexts))
body["messages"].append(
{
"role": "assistant",
"content": f"Tool calls and results: {context_string}",
}
)

# If there are citations, add them to the data_items
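Illustrative sketch (not part of the diff): for non-Ollama models, tool results are no longer injected through the RAG template; they are appended as an assistant turn instead. How the structured contexts end up in that message, with made-up sample data:

# Made-up structured contexts as produced by the tools handler above.
contexts = [
    {"tool": "get_weather", "params": {"city": "Oslo"}, "output": "Sunny in Oslo"},
    {"tool": "calculator", "params": {"expression": "2+2"}, "output": "4"},
]
body = {"messages": [{"role": "user", "content": "Weather in Oslo?"}]}

# Same "tool: output" join and assistant-turn append as the new code above.
context_string = "\n".join(map(lambda x: x["tool"] + ": " + x["output"], contexts))
body["messages"].append(
    {
        "role": "assistant",
        "content": f"Tool calls and results: {context_string}",
    }
)

print(body["messages"][-1]["content"])
# Tool calls and results: get_weather: Sunny in Oslo
# calculator: 4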
@@ -1337,8 +1355,6 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u

@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
print("generate_title")

model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(