37 changes: 37 additions & 0 deletions alembic/versions/4fc61e385531_add_faq_and_chunk_config.py
@@ -0,0 +1,37 @@
"""add chunk_config, FAQ tables and FAQ config columns

Revision ID: 4fc61e385531
Revises: 1432eea7c5b9
Create Date: 2025-12-25 17:00:00.000000

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa

from db.op.safe_add import safe_add_column


# revision identifiers, used by Alembic.
revision: str = '4fc61e385531'
down_revision: Union[str, Sequence[str], None] = '1432eea7c5b9'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Upgrade schema."""
# 1. Add chunk_config column to pai_knowledgebase_file table
safe_add_column('pai_knowledgebase_file', sa.Column('chunk_config', sa.JSON(), nullable=True))
safe_add_column('pai_chatbot_model', sa.Column('enable_faq', sa.Boolean(), nullable=True, default=False))
safe_add_column('pai_chatbot_model', sa.Column('faq_config', sa.JSON(), nullable=True))


def downgrade() -> None:
"""Downgrade schema."""
# Remove chunk_config column from pai_knowledgebase_file table
op.drop_column('pai_knowledgebase_file', 'chunk_config')
op.drop_column('pai_chatbot_model', 'faq_config')
op.drop_column('pai_chatbot_model', 'enable_faq')

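The helper `safe_add_column` is imported from `db.op.safe_add`, whose implementation is outside this diff. Below is a minimal sketch of what such a helper presumably does, assuming it consults SQLAlchemy's inspector so that `upgrade()` stays idempotent when a column already exists. Note also that `default=False` on `enable_faq` is a client-side SQLAlchemy default, so existing rows keep NULL unless a `server_default` such as `sa.false()` is added:

```python
# Hypothetical sketch -- the real safe_add_column lives in db.op.safe_add and
# is not shown in this diff. Assumed behavior: skip the ALTER TABLE when the
# column already exists, so the migration can be re-run safely.
import sqlalchemy as sa
from alembic import op


def safe_add_column(table_name: str, column: sa.Column) -> None:
    """Add `column` to `table_name` unless a column of that name exists."""
    inspector = sa.inspect(op.get_bind())
    existing = {col["name"] for col in inspector.get_columns(table_name)}
    if column.name not in existing:
        op.add_column(table_name, column)
```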
27 changes: 27 additions & 0 deletions backend/agent/actor.py
@@ -14,6 +14,7 @@
from common.chat.constants import MessageRole
from opentelemetry import trace
from utils.json_utils import parse_tool_arguments
from agent.tool_utils import check_and_handle_return_direct


@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
@@ -118,6 +119,19 @@ async def gen():

state.current_tool_call = None
observations += message_content + "\n\n"

# If the tool is flagged return_direct=True, yield its result and stop the loop
tool_obj = self.tool_fn_map[tool_name]
return_chunk = check_and_handle_return_direct(
    tool_obj=tool_obj,
    tool_name=tool_name,
    tool_content=tool_content,
    tool_error=tool_error,
    agent_name=self.name,
)
if return_chunk:
    yield return_chunk
    return
act_prompt = self.build_prompt(state)
messages = [{"role": "system", "content": act_prompt}] + messages

@@ -212,6 +226,19 @@ async def gen():
)
observations += message_content + "\n\n"

# If the tool is flagged return_direct=True, yield its result and stop the loop
tool_obj = self.tool_fn_map[function_name]
return_chunk = check_and_handle_return_direct(
    tool_obj=tool_obj,
    tool_name=function_name,
    tool_content=tool_content,
    tool_error=tool_error,
    agent_name=self.name,
)
if return_chunk:
    yield return_chunk
    return

# Guard against exceeding the maximum number of steps
if react_step > self.max_steps:
logger.warning(f"Reached max recursion steps: {self.max_steps}")
14 changes: 14 additions & 0 deletions backend/agent/actor_with_plan.py
@@ -10,6 +10,7 @@
from extensions.trace.base import use_current_span
from opentelemetry import trace
from utils.json_utils import parse_tool_arguments
from agent.tool_utils import check_and_handle_return_direct


@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
@@ -157,6 +158,19 @@ async def gen():
        result=tool_content,
        error=tool_error
    )

    # If the tool is flagged return_direct=True, yield its result and stop the loop
    tool_obj = self.tool_fn_map[function_name]
    return_chunk = check_and_handle_return_direct(
        tool_obj=tool_obj,
        tool_name=function_name,
        tool_content=tool_content,
        tool_error=tool_error,
        agent_name=self.name,
    )
    if return_chunk:
        yield return_chunk
        return
else:
    break

60 changes: 60 additions & 0 deletions backend/agent/tool_utils.py
@@ -0,0 +1,60 @@
"""Utility functions for handling tool calls and results."""

import json
from typing import Optional
from llama_index.core.tools.function_tool import FunctionTool
from common.llm.models import TextChunk
from loguru import logger


def check_and_handle_return_direct(
    tool_obj: FunctionTool,
    tool_name: str,
    tool_content: Optional[str],
    tool_error: Optional[str],
    agent_name: str = "agent",
) -> Optional[TextChunk]:
    """
    Check whether a tool is marked return_direct=True and, if so, format its result.

    Args:
        tool_obj: The FunctionTool object
        tool_name: Name of the tool
        tool_content: Content returned by the tool (None on error)
        tool_error: Error message if the tool call failed (None on success)
        agent_name: Name of the agent (for logging)

    Returns:
        TextChunk if return_direct=True, None otherwise
    """
    return_direct = getattr(tool_obj.metadata, 'return_direct', False)

    if not return_direct:
        return None

    logger.info(f"[{agent_name}] Tool {tool_name} has return_direct=True, returning tool result directly.")

    if tool_error:
        return TextChunk(delta=f"Tool call failed: {tool_error}")

    if not tool_content:
        return TextChunk(delta="Tool call succeeded but returned no content.")

    try:
        result_data = json.loads(tool_content)
        if isinstance(result_data, dict) and "result" in result_data:
            # Format FAQ results or similar structured results
            formatted_result = ""
            for item in result_data.get("result", []):
                if isinstance(item, dict):
                    content = item.get("content", "")
                    if content:
                        formatted_result += content + "\n\n"
            if formatted_result:
                return TextChunk(delta=formatted_result.strip())
        # Fall through: pass unrecognized JSON along unchanged
        return TextChunk(delta=tool_content)
    except json.JSONDecodeError:
        # Non-JSON content: return it as-is
        return TextChunk(delta=tool_content)
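
For context, a usage sketch of the new helper. The `return_direct` flag is set when the tool is constructed (`FunctionTool.from_defaults` accepts it in recent llama_index releases), and the `{"result": [{"content": ...}]}` payload shape is inferred from the formatting branch above — both the stub tool and the payload here are assumptions, not part of this diff:

```python
# Usage sketch (assumed payload shape; not part of this diff).
import json
from llama_index.core.tools.function_tool import FunctionTool
from agent.tool_utils import check_and_handle_return_direct


def search_faq(query: str) -> str:
    """Stub FAQ tool mimicking the {"result": [{"content": ...}]} shape."""
    return json.dumps({"result": [{"content": f"FAQ answer for: {query}"}]})


faq_tool = FunctionTool.from_defaults(
    fn=search_faq,
    name="search_faq",
    return_direct=True,  # tells the agent loop to surface the result verbatim
)

chunk = check_and_handle_return_direct(
    tool_obj=faq_tool,
    tool_name="search_faq",
    tool_content=search_faq("how do I reset my password?"),
    tool_error=None,
)
assert chunk is not None  # return_direct=True, so a TextChunk comes back
print(chunk.delta)  # -> "FAQ answer for: how do I reset my password?"
```

Because the agent loops yield the chunk and then return, the LLM never sees the tool output as an observation — the intended behavior for FAQ-style canned answers.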