1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to

- 🐛(front) fix target blank links in chat #103
- 🚑️(posthog) pass str instead of UUID for user PK #134
- ⚡️(web-search) keep running when tool call fails #137


## [0.0.7] - 2025-10-28
104 changes: 104 additions & 0 deletions src/backend/chat/agent_rag/document_rag_backends/albert_rag_backend.py
@@ -8,6 +8,7 @@

from django.conf import settings

import httpx
import requests

from chat.agent_rag.albert_api_constants import Searches
@@ -65,6 +66,27 @@ def create_collection(self, name: str, description: Optional[str] = None) -> str
self.collection_id = str(response.json()["id"])
return self.collection_id

async def acreate_collection(self, name: str, description: Optional[str] = None) -> str:
"""
Create a temporary collection for the search operation.
This method should handle the logic to create or retrieve an existing collection.
"""
async with httpx.AsyncClient(timeout=settings.ALBERT_API_TIMEOUT) as client:
response = await client.post(
self._collections_endpoint,
headers=self._headers,
json={
"name": name,
"description": description or self._default_collection_description,
"visibility": "private",
},
timeout=settings.ALBERT_API_TIMEOUT,
)
response.raise_for_status()

self.collection_id = str(response.json()["id"])
return self.collection_id

def delete_collection(self) -> None:
"""
Delete the current collection
@@ -76,6 +98,18 @@ def delete_collection(self) -> None:
)
response.raise_for_status()

async def adelete_collection(self) -> None:
"""
Asynchronously delete the current collection
"""
async with httpx.AsyncClient(timeout=settings.ALBERT_API_TIMEOUT) as client:
response = await client.delete(
urljoin(f"{self._collections_endpoint}/", self.collection_id),
headers=self._headers,
timeout=settings.ALBERT_API_TIMEOUT,
)
response.raise_for_status()

def parse_pdf_document(self, name: str, content_type: str, content: BytesIO) -> str:
"""
Parse the PDF document content and return the text content.
@@ -150,6 +184,31 @@ def store_document(self, name: str, content: str) -> None:
logger.debug(response.json())
response.raise_for_status()

async def astore_document(self, name: str, content: str) -> None:
"""
Store the document content in the Albert collection.
This method should handle the logic to send the document content to the Albert API.

Args:
name (str): The name of the document.
content (str): The content of the document in Markdown format.
"""
async with httpx.AsyncClient(timeout=settings.ALBERT_API_TIMEOUT) as client:
response = await client.post(
urljoin(self._base_url, self._documents_endpoint),
headers=self._headers,
files={
"file": (f"{name}.md", BytesIO(content.encode("utf-8")), "text/markdown"),
},
data={
"collection": int(self.collection_id),
"metadata": json.dumps({"document_name": name}), # undocumented API
},
timeout=settings.ALBERT_API_TIMEOUT,
)
logger.debug(response.json())
response.raise_for_status()
Comment on lines +187 to +210
⚠️ Potential issue | 🟡 Minor

Fix error handling order and remove redundant timeout.

Two issues:

  1. Critical - Wrong error handling order: Lines 209-210 call response.json() before raise_for_status(). If the API returns an error status, parsing the body first may log the error payload, or fail with a JSON decode error when the body is not JSON, before the status check ever runs. Check the status before attempting to parse the response.

  2. Redundant timeout: Line 207 specifies timeout again, even though it is already configured in the AsyncClient constructor at line 196.

Apply this diff:

     async def astore_document(self, name: str, content: str) -> None:
         """
         Store the document content in the Albert collection.
         This method should handle the logic to send the document content to the Albert API.
 
         Args:
             name (str): The name of the document.
             content (str): The content of the document in Markdown format.
         """
         async with httpx.AsyncClient(timeout=settings.ALBERT_API_TIMEOUT) as client:
             response = await client.post(
                 urljoin(self._base_url, self._documents_endpoint),
                 headers=self._headers,
                 files={
                     "file": (f"{name}.md", BytesIO(content.encode("utf-8")), "text/markdown"),
                 },
                 data={
                     "collection": int(self.collection_id),
                     "metadata": json.dumps({"document_name": name}),  # undocumented API
                 },
-                timeout=settings.ALBERT_API_TIMEOUT,
             )
-            logger.debug(response.json())
             response.raise_for_status()
+            logger.debug(response.json())
🤖 Prompt for AI Agents
In src/backend/chat/agent_rag/document_rag_backends/albert_rag_backend.py around
lines 187 to 210, fix two issues: move response.raise_for_status() to run
immediately after the POST (before any response.json() calls) so HTTP errors are
detected before parsing, and remove the redundant timeout parameter from the
client.post() call (the AsyncClient already has timeout configured); keep
logging the parsed JSON only after raise_for_status() succeeds.


def search(self, query, results_count: int = 4) -> RAGWebResults:
"""
Perform a search using the Albert API based on the provided query.
@@ -190,3 +249,48 @@ def search(self, query, results_count: int = 4) -> RAGWebResults:
completion_tokens=searches.usage.completion_tokens,
),
)

async def asearch(self, query, results_count: int = 4) -> RAGWebResults:
"""
Perform an asynchronous search using the Albert API based on the provided query.

Args:
query (str): The search query.
results_count (int): The number of results to return.

Returns:
RAGWebResults: The search results.
"""
async with httpx.AsyncClient(timeout=settings.ALBERT_API_TIMEOUT) as client:
response = await client.post(
urljoin(self._base_url, self._search_endpoint),
headers=self._headers,
json={
"collections": [int(self.collection_id)],
"prompt": query,
"score_threshold": 0.6,
"k": results_count, # Number of chunks to return from the search
},
timeout=settings.ALBERT_API_TIMEOUT,
)

logger.debug("Search response: %s %s", response.text, response.status_code)

response.raise_for_status()

searches = Searches(**response.json())

return RAGWebResults(
data=[
RAGWebResult(
url=result.chunk.metadata["document_name"],
content=result.chunk.content,
score=result.score,
)
for result in searches.data
],
usage=RAGWebUsage(
prompt_tokens=searches.usage.prompt_tokens,
completion_tokens=searches.usage.completion_tokens,
),
)
@@ -1,10 +1,12 @@
"""Implementation of the Albert API for RAG document search."""

import logging
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from io import BytesIO
from typing import Optional

from asgiref.sync import sync_to_async

from chat.agent_rag.constants import RAGWebResults

logger = logging.getLogger(__name__)
@@ -25,6 +27,13 @@ def create_collection(self, name: str, description: Optional[str] = None) -> str
"""
raise NotImplementedError("Must be implemented in subclass.")

async def acreate_collection(self, name: str, description: Optional[str] = None) -> str:
"""
Create a temporary collection for the search operation.
This method should handle the logic to create or retrieve an existing collection.
"""
return await sync_to_async(self.create_collection)(name=name, description=description)

def parse_document(self, name: str, content_type: str, content: BytesIO):
"""
Parse the document and prepare it for the search operation.
@@ -43,15 +52,26 @@ def parse_document(self, name: str, content_type: str, content: BytesIO):

def store_document(self, name: str, content: str) -> None:
"""
Store the document content in the Albert collection.
This method should handle the logic to send the document content to the Albert API.
Store the document content in the collection.
This method should handle the logic to send the document content to the API.

Args:
name (str): The name of the document.
content (str): The content of the document in Markdown format.
"""
raise NotImplementedError("Must be implemented in subclass.")

async def astore_document(self, name: str, content: str) -> None:
"""
Store the document content in the collection.
This method should handle the logic to send the document content to the API.

Args:
name (str): The name of the document.
content (str): The content of the document in Markdown format.
"""
return await sync_to_async(self.store_document)(name=name, content=content)

def parse_and_store_document(self, name: str, content_type: str, content: BytesIO) -> str:
"""
Parse the document and store it in the Albert collection.
@@ -75,12 +95,25 @@ def delete_collection(self) -> None:
"""
raise NotImplementedError("Must be implemented in subclass.")

async def adelete_collection(self) -> None:
"""
Delete the collection.
This method should handle the logic to delete the collection from the backend.
"""
return await sync_to_async(self.delete_collection)()

def search(self, query, results_count: int = 4) -> RAGWebResults:
"""
Search the collection for the given query.
"""
raise NotImplementedError("Must be implemented in subclass.")

async def asearch(self, query, results_count: int = 4) -> RAGWebResults:
"""
Search the collection for the given query.
"""
return await sync_to_async(self.search)(query=query, results_count=results_count)

@classmethod
@contextmanager
def temporary_collection(cls, name: str, description: Optional[str] = None):
@@ -92,3 +125,15 @@ def temporary_collection(cls, name: str, description: Optional[str] = None):
yield backend
finally:
backend.delete_collection()

@classmethod
@asynccontextmanager
async def temporary_collection_async(cls, name: str, description: Optional[str] = None):
"""Context manager for RAG backend with temporary collections."""
backend = cls()

await backend.acreate_collection(name=name, description=description)
try:
yield backend
finally:
await backend.adelete_collection()
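
For illustration, a minimal usage sketch of the new async context manager; the backend class name, import path, and surrounding coroutine are assumptions for this example, not part of the diff:

```python
# Hypothetical usage; AlbertRagBackend and the import path are assumed names.
from chat.agent_rag.document_rag_backends.albert_rag_backend import AlbertRagBackend


async def answer_from_document(name: str, markdown: str, query: str):
    # The collection is created on entry and deleted in the finally branch,
    # even if storing or searching raises.
    async with AlbertRagBackend.temporary_collection_async(name="tmp-search") as backend:
        await backend.astore_document(name=name, content=markdown)
        return await backend.asearch(query, results_count=4)
```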
154 changes: 154 additions & 0 deletions src/backend/chat/tests/tools/test_utils.py
@@ -0,0 +1,154 @@
"""Tests for chat tool utilities."""

import inspect
from typing import get_type_hints

import pytest
from pydantic_ai import ModelRetry, RunContext

from chat.tools.exceptions import ModelCannotRetry
from chat.tools.utils import last_model_retry_soft_fail


def test_last_model_retry_soft_fail_preserves_function_metadata():
"""Test that the decorator preserves function metadata for schema generation."""

@last_model_retry_soft_fail
async def example_tool(ctx: RunContext, query: str, limit: int = 10) -> str: # pylint: disable=unused-argument
"""
Example tool function.

Args:
ctx: The run context.
query: The search query.
limit: Maximum number of results.

Returns:
The search results.
"""
return f"Results for {query} (limit: {limit})"

# Check that function name is preserved
assert example_tool.__name__ == "example_tool"

# Check that docstring is preserved
assert example_tool.__doc__ is not None
assert "Example tool function" in example_tool.__doc__

# Check that signature is preserved
sig = inspect.signature(example_tool)
assert "ctx" in sig.parameters
assert "query" in sig.parameters
assert "limit" in sig.parameters
assert sig.parameters["limit"].default == 10

# Check that type hints are preserved
type_hints = get_type_hints(example_tool)
assert "query" in type_hints
assert type_hints["query"] == str
assert "limit" in type_hints
assert type_hints["limit"] == int
assert type_hints["return"] == str


@pytest.mark.asyncio
async def test_last_model_retry_soft_fail_normal_execution():
"""Test that the decorator doesn't interfere with normal execution."""

@last_model_retry_soft_fail
async def example_tool(_ctx: RunContext, value: str) -> str:
"""Example tool."""
return f"Result: {value}"

# Create a mock context
class MockContext:
"""Fake context for testing."""

max_retries = 3
retries = {}
tool_name = "example_tool"

ctx = MockContext()
result = await example_tool(ctx, "test")
assert result == "Result: test"


@pytest.mark.asyncio
async def test_last_model_retry_soft_fail_handles_retry_exception():
"""Test that the decorator handles ModelRetry exceptions correctly."""

@last_model_retry_soft_fail
async def failing_tool(_ctx: RunContext, should_fail: bool) -> str:
"""Tool that can raise ModelRetry."""
if should_fail:
raise ModelRetry("Please retry with different parameters")
return "Success"

# Create a mock context
class MockContext:
"""Fake context for testing."""

max_retries = 3
retries = {}
tool_name = "failing_tool"

ctx = MockContext()

# Test when retries haven't been exhausted - should re-raise
with pytest.raises(ModelRetry):
await failing_tool(ctx, should_fail=True)


@pytest.mark.asyncio
async def test_last_model_retry_soft_fail_returns_message_when_max_retries_reached():
"""Test that the decorator returns the error message when max retries is reached."""

@last_model_retry_soft_fail
async def failing_tool(_ctx: RunContext, should_fail: bool) -> str:
"""Tool that can raise ModelRetry."""
if should_fail:
raise ModelRetry("Please retry with different parameters.")
return "Success"

# Create a mock context with max retries already reached
class MockContext:
"""Fake context for testing."""

max_retries = 3
retries = {"failing_tool": 3}
tool_name = "failing_tool"

ctx = MockContext()

# Test when retries have been exhausted - should return message
result = await failing_tool(ctx, should_fail=True)
assert result == (
"Please retry with different parameters. "
"You must explain this to the user and not try to answer based on your knowledge."
)


@pytest.mark.asyncio
async def test_last_model_retry_soft_fail_returns_message_when_model_cannot_retry():
"""Test that the decorator returns the error message when ModelCannotRetry is raised."""

@last_model_retry_soft_fail
async def failing_tool(_ctx: RunContext, should_fail: bool) -> str:
"""Tool that can raise ModelRetry."""
if should_fail:
raise ModelCannotRetry("This is broken duh.")
return "Success"

# Create a mock context with max retries already reached
class MockContext:
"""Fake context for testing."""

max_retries = 3
retries = {"failing_tool": 3}
tool_name = "failing_tool"

ctx = MockContext()

# Test when retries have been exhausted - should return message
result = await failing_tool(ctx, should_fail=True)
assert result == "This is broken duh."