diff --git a/alembic/versions/4fc61e385531_add_faq_and_chunk_config.py b/alembic/versions/4fc61e385531_add_faq_and_chunk_config.py new file mode 100644 index 000000000..9fa924fc6 --- /dev/null +++ b/alembic/versions/4fc61e385531_add_faq_and_chunk_config.py @@ -0,0 +1,37 @@ +"""add chunk_config, FAQ tables and FAQ config columns + +Revision ID: 4fc61e385531 +Revises: 1432eea7c5b9 +Create Date: 2025-12-25 17:00:00.000000 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from db.op.safe_add import safe_add_column + + +# revision identifiers, used by Alembic. +revision: str = '4fc61e385531' +down_revision: Union[str, Sequence[str], None] = '1432eea7c5b9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # 1. Add chunk_config column to pai_knowledgebase_file table + safe_add_column('pai_knowledgebase_file', sa.Column('chunk_config', sa.JSON(), nullable=True)) + safe_add_column('pai_chatbot_model', sa.Column('enable_faq', sa.Boolean(), nullable=True, default=False)) + safe_add_column('pai_chatbot_model', sa.Column('faq_config', sa.JSON(), nullable=True)) + + +def downgrade() -> None: + """Downgrade schema.""" + # Remove chunk_config column from pai_knowledgebase_file table + op.drop_column('pai_knowledgebase_file', 'chunk_config') + op.drop_column('pai_chatbot_model', 'faq_config') + op.drop_column('pai_chatbot_model', 'enable_faq') + diff --git a/backend/agent/actor.py b/backend/agent/actor.py index 70e1bc30f..a1f758f9d 100644 --- a/backend/agent/actor.py +++ b/backend/agent/actor.py @@ -14,6 +14,7 @@ from common.chat.constants import MessageRole from opentelemetry import trace from utils.json_utils import parse_tool_arguments +from agent.tool_utils import check_and_handle_return_direct @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) @@ -118,6 +119,19 @@ async def gen(): state.current_tool_call = None 
observations += message_content + "\n\n" + + # Check if tool has return_direct=True, if so, return directly + tool_obj = self.tool_fn_map[tool_name] + return_chunk = check_and_handle_return_direct( + tool_obj=tool_obj, + tool_name=tool_name, + tool_content=tool_content, + tool_error=tool_error, + agent_name=self.name, + ) + if return_chunk: + yield return_chunk + return act_prompt = self.build_prompt(state) messages = [{"role": "system", "content": act_prompt}] + messages @@ -212,6 +226,19 @@ async def gen(): ) observations += message_content + "\n\n" + # Check if tool has return_direct=True, if so, return directly + tool_obj = self.tool_fn_map[function_name] + return_chunk = check_and_handle_return_direct( + tool_obj=tool_obj, + tool_name=function_name, + tool_content=tool_content, + tool_error=tool_error, + agent_name=self.name, + ) + if return_chunk: + yield return_chunk + return + # 超出步数保护 if react_step > self.max_steps: logger.warning(f"Reached max recursion steps: {self.max_steps}") diff --git a/backend/agent/actor_with_plan.py b/backend/agent/actor_with_plan.py index 440796b98..16dc5c0e0 100644 --- a/backend/agent/actor_with_plan.py +++ b/backend/agent/actor_with_plan.py @@ -10,6 +10,7 @@ from extensions.trace.base import use_current_span from opentelemetry import trace from utils.json_utils import parse_tool_arguments +from agent.tool_utils import check_and_handle_return_direct @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) @@ -157,6 +158,19 @@ async def gen(): result=tool_content, error=tool_error ) + + # Check if tool has return_direct=True, if so, return directly + tool_obj = self.tool_fn_map[function_name] + return_chunk = check_and_handle_return_direct( + tool_obj=tool_obj, + tool_name=function_name, + tool_content=tool_content, + tool_error=tool_error, + agent_name=self.name, + ) + if return_chunk: + yield return_chunk + return else: break diff --git a/backend/agent/tool_utils.py b/backend/agent/tool_utils.py new file mode 100644 index 
000000000..580f0eb36 --- /dev/null +++ b/backend/agent/tool_utils.py @@ -0,0 +1,60 @@ +"""Utility functions for handling tool calls and results.""" + +import json +from typing import Optional +from llama_index.core.tools.function_tool import FunctionTool +from common.llm.models import TextChunk +from loguru import logger + + +def check_and_handle_return_direct( + tool_obj: FunctionTool, + tool_name: str, + tool_content: Optional[str], + tool_error: Optional[str], + agent_name: str = "agent", +) -> Optional[TextChunk]: + """ + Check if a tool has return_direct=True and format the result accordingly. + + Args: + tool_obj: The FunctionTool object + tool_name: Name of the tool + tool_content: Content returned by the tool (None if error) + tool_error: Error message if tool call failed (None if success) + agent_name: Name of the agent (for logging) + + Returns: + TextChunk if return_direct=True, None otherwise + """ + return_direct = getattr(tool_obj.metadata, 'return_direct', False) + + if not return_direct: + return None + + logger.info(f"[{agent_name}] Tool {tool_name} has return_direct=True, returning tool result directly.") + + if tool_error: + return TextChunk(delta=f"工具调用失败: {tool_error}") + + if not tool_content: + return TextChunk(delta="工具调用成功,但未返回内容。") + + try: + result_data = json.loads(tool_content) + if isinstance(result_data, dict) and "result" in result_data: + # Format FAQ results or similar structured results + formatted_result = "" + for item in result_data.get("result", []): + if isinstance(item, dict): + content = item.get("content", "") + if content: + formatted_result += content + "\n\n" + if formatted_result: + return TextChunk(delta=formatted_result.strip()) + else: + return TextChunk(delta=tool_content) + else: + return TextChunk(delta=tool_content) + except (json.JSONDecodeError, Exception): + return TextChunk(delta=tool_content) diff --git a/backend/api/v1/config_apis/chatapp.py b/backend/api/v1/config_apis/chatapp.py index 
87ea7e0ce..6ee6e54d9 100644 --- a/backend/api/v1/config_apis/chatapp.py +++ b/backend/api/v1/config_apis/chatapp.py @@ -1,24 +1,502 @@ ### Embedding configuration API ### -from fastapi import APIRouter, Depends, Query +import time +import asyncio +import tempfile +import os +from datetime import datetime, timezone +from fastapi import APIRouter, Depends, Query, File, UploadFile, Form from sqlmodel import select, func from sqlmodel.ext.asyncio.session import AsyncSession from db.models.chatbot import ( ChatBotCreate, ChatBotEntity, ) +from pairag.file.models.file_item import FileItem +import uuid +import hashlib +from rag.kb_file_client import kb_file_client +from rag.parse_utils import sanitize_text +from tqdm import tqdm from db.db_context import get_db_session from sqlalchemy.exc import IntegrityError from common.chat.response_model import PagedResult, ResponseModel, success_response -from api.v1.utils.paginate import get_pagination_meta -from service.injection import get_chatapp_service, get_tenant_id +from service.injection import get_chatapp_service, get_tenant_id, get_faq_config_service, get_faq_item_service, get_rag_service, get_embedding_service, get_knowledgebase_service, get_file_service from service.tool.chatapp_service import ChatappService +from service.tool.faq_config_service import FAQConfigService +from service.tool.faq_item_service import FAQItemService +from service.injection import get_faq_item_service +from db.models.faq_item import FAQItemCreate +from service.knowledgebase.rag_service import RagService +from service.knowledgebase.knowledgebase_service import KnowledgebaseService +from service.knowledgebase.file_service import FileService +from service.model.embedding_service import EmbeddingService +from db.models.knowledgebase.knowledgebase import KnowledgebaseCreate, RetrievalConfig, ChunkConfig, TableParserConfig +from common.knowledgebase.constants import FAQ_KNOWLEDGEBASE_NAME, DEFAULT_FAQ_SIMILARITY_THRESHOLD +from 
common.knowledgebase.types import VectorIndexRetrievalType, FileStatus +from rag.file_item_utils import to_file_entity +from typing import Optional, List +from io import BytesIO +import json from api.api_exception import ApiException import traceback from loguru import logger app_router = APIRouter() +# Import FAQ dependencies +from db.models.faq_config import FAQConfigCreate +from db.models.faq_item import FAQItemCreate, FAQItemEntity + +# FAQ routes - MUST be defined before /{id} routes to avoid route conflicts +# FastAPI matches routes in order, so more specific routes must come first +@app_router.post("/{app_id}/faqs", response_model=ResponseModel[FAQItemEntity], tags=["FAQ"]) +async def create_faq_item( + app_id: str, + faq_item_create: FAQItemCreate, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_db_session), + chatapp_service: ChatappService = Depends(get_chatapp_service), + faq_item_service: FAQItemService = Depends(get_faq_item_service), + rag_service: RagService = Depends(get_rag_service), +): + logger.info(f"Creating FAQ item for app_id: {app_id}") + try: + # Get chatbot by app_id to get chatbot_id + chatbot = await chatapp_service.get_chatapp_by_app_id(app_id=app_id, tenant_id=tenant_id) + if not chatbot: + raise ApiException(code=404, message=f"应用 '{app_id}' 不存在。") + + if not chatbot.enable_faq or not chatbot.faq_config: + raise ApiException(code=400, message="请先启用FAQ功能。") + + faq_item = await faq_item_service.create_faq_item( + chatbot_id=chatbot.app_id, + faq_item_data=faq_item_create, + tenant_id=tenant_id, + ) + await faq_item_service.save_faq_to_knowledgebase(faq_item, tenant_id, rag_service) + await session.commit() + await session.refresh(faq_item) + return success_response(data=faq_item, message="创建FAQ成功。") + except ValueError as e: + logger.error(f"Failed to create FAQ item: {str(e)}") + raise ApiException(code=400, message=str(e)) + except ApiException: + raise + except Exception as e: + logger.error(f"Failed 
to create FAQ item: {traceback.format_exc()}") + raise ApiException(code=500, message=f"创建FAQ失败: {traceback.format_exc()}") + +@app_router.get("/{app_id}/faqs", tags=["FAQ"]) +async def list_faq_items( + app_id: str, + page: int = Query(default=1, ge=1), + size: int = Query(default=100, le=1000), + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_db_session), + chatapp_service: ChatappService = Depends(get_chatapp_service), + faq_item_service: FAQItemService = Depends(get_faq_item_service), +): + logger.info(f"Listing FAQ items for app_id: {app_id}") + try: + # Get chatbot by app_id to get chatbot_id + chatbot = await chatapp_service.get_chatapp_by_app_id(app_id=app_id, tenant_id=tenant_id) + if not chatbot: + raise ApiException(code=404, message=f"应用 '{app_id}' 不存在。") + + faq_items = await faq_item_service.list_faq_items( + chatbot_id=chatbot.app_id, + tenant_id=tenant_id, + page=page, + size=size, + ) + return success_response(data=faq_items, message="查询FAQ列表成功。") + except ValueError as e: + logger.error(f"Failed to list FAQ items: {str(e)}") + raise ApiException(code=400, message=str(e)) + except ApiException: + raise + except Exception as e: + logger.error(f"Failed to list FAQ items: {traceback.format_exc()}") + raise ApiException(code=500, message=f"查询FAQ列表失败: {traceback.format_exc()}") + +@app_router.put("/{app_id}/faqs/{faq_item_id}", response_model=ResponseModel[FAQItemEntity], tags=["FAQ"]) +async def update_faq_item( + app_id: str, + faq_item_id: str, + faq_item_update: FAQItemCreate, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_db_session), + faq_item_service: FAQItemService = Depends(get_faq_item_service), + rag_service: RagService = Depends(get_rag_service), +): + try: + faq_item = await faq_item_service.update_faq_item( + id=faq_item_id, update_data=faq_item_update, tenant_id=tenant_id, rag_service=rag_service + ) + await session.commit() + await session.refresh(faq_item) + return 
success_response(data=faq_item, message="更新FAQ成功。") + except ValueError as e: + logger.error(f"Failed to update FAQ item: {str(e)}") + raise ApiException(code=400, message=str(e)) + except Exception as e: + logger.error(f"Failed to update FAQ item: {traceback.format_exc()}") + raise ApiException(code=500, message=f"更新FAQ失败: {traceback.format_exc()}") + +@app_router.delete("/{app_id}/faqs/{faq_item_id}", tags=["FAQ"]) +async def delete_faq_item( + app_id: str, + faq_item_id: str, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_db_session), + faq_item_service: FAQItemService = Depends(get_faq_item_service), + rag_service: RagService = Depends(get_rag_service), +): + try: + await faq_item_service.delete_faq_item(id=faq_item_id, tenant_id=tenant_id, rag_service=rag_service) + await session.commit() + return success_response(message=f"FAQ'{faq_item_id}'删除成功。") + except ValueError as e: + logger.error(f"Failed to delete FAQ item: {str(e)}") + raise ApiException(code=400, message=str(e)) + except Exception as e: + logger.error(f"Failed to delete FAQ item: {traceback.format_exc()}") + raise ApiException(code=500, message=f"删除FAQ失败: {traceback.format_exc()}") + +@app_router.get("/{app_id}/faq-config", response_model=ResponseModel[FAQConfigCreate], tags=["FAQ"]) +async def get_faq_config( + app_id: str, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_db_session), + chatapp_service: ChatappService = Depends(get_chatapp_service), + faq_config_service: FAQConfigService = Depends(get_faq_config_service), +): + """Get FAQ config for an app.""" + try: + # Get chatbot by app_id to get chatbot_id + chatbot = await chatapp_service.get_chatapp_by_app_id(app_id=app_id, tenant_id=tenant_id) + if not chatbot: + raise ApiException(code=404, message=f"应用 '{app_id}' 不存在。") + + faq_config = await faq_config_service.get_or_create_faq_config( + chatbot_id=chatbot.id, tenant_id=tenant_id + ) + await session.commit() + return 
success_response(data=faq_config, message="获取FAQ配置成功。") + except ValueError as e: + logger.error(f"Failed to get FAQ config: {str(e)}") + raise ApiException(code=400, message=str(e)) + except ApiException: + raise + except Exception as e: + logger.error(f"Failed to get FAQ config: {traceback.format_exc()}") + raise ApiException(code=500, message=f"获取FAQ配置失败: {traceback.format_exc()}") + +@app_router.put("/{app_id}/faq-config", response_model=ResponseModel[FAQConfigCreate], tags=["FAQ"]) +async def update_faq_config( + app_id: str, + faq_config_data: FAQConfigCreate, + tenant_id: str = Depends(get_tenant_id), + session: AsyncSession = Depends(get_db_session), + chatapp_service: ChatappService = Depends(get_chatapp_service), + faq_config_service: FAQConfigService = Depends(get_faq_config_service), + knowledgebase_service: KnowledgebaseService = Depends(get_knowledgebase_service), +): + """Update FAQ config for an app.""" + try: + # Get chatbot by app_id to get chatbot_id + chatbot = await chatapp_service.get_chatapp_by_app_id(app_id=app_id, tenant_id=tenant_id) + if not chatbot: + raise ApiException(code=404, message=f"应用 '{app_id}' 不存在。") + + # Update FAQ config with full synchronization logic + updated_faq_config = await faq_config_service.update_faq_config_with_sync( + app_id=app_id, + chatbot_id=chatbot.id, + update_data=faq_config_data, + tenant_id=tenant_id, + knowledgebase_service=knowledgebase_service + ) + + await session.commit() + return success_response(data=updated_faq_config, message="更新FAQ配置成功。") + except ValueError as e: + logger.error(f"Failed to update FAQ config: {str(e)}") + raise ApiException(code=400, message=str(e)) + except ApiException: + raise + except Exception as e: + logger.error(f"Failed to update FAQ config: {traceback.format_exc()}") + raise ApiException(code=500, message=f"更新FAQ配置失败: {traceback.format_exc()}") + +MAX_CHECK_ATTEMPTS = 100 +CHECK_INTERVAL = 3 + +@app_router.post("/{app_id}/faq-files", tags=["FAQ"]) +async def 
upload_faq_files( + app_id: str, + files: Optional[List[UploadFile]] = File(...), + table_config: Optional[str] = Form(None, description="JSON string of table_config, shared by all files in this upload"), + session: AsyncSession = Depends(get_db_session), + chatapp_service: ChatappService = Depends(get_chatapp_service), + embedding_service: EmbeddingService = Depends(get_embedding_service), + knowledgebase_service: KnowledgebaseService = Depends(get_knowledgebase_service), + rag_service: RagService = Depends(get_rag_service), + tenant_id: str = Depends(get_tenant_id), +): + """Upload and parse FAQ files directly without storing file entity.""" + knowledgebase = None + try: + if not files: + raise ApiException(code=400, message="没有上传任何文件。") + # Get chatbot by app_id to get chatbot_id + chatbot = await chatapp_service.get_chatapp_by_app_id(app_id=app_id, tenant_id=tenant_id) + if not chatbot: + raise ApiException(code=404, message=f"应用 '{app_id}' 不存在。") + + if not chatbot.enable_faq or not chatbot.faq_config: + raise ApiException(code=400, message="请先启用FAQ功能。") + + # Get or create FAQ knowledgebase + kb_name = f"{app_id}_{FAQ_KNOWLEDGEBASE_NAME}" + knowledgebase = await knowledgebase_service.get_knowledgebase_by_name(kb_name, tenant_id=tenant_id) + default_embedding_config = await embedding_service.get_default_embedding(tenant_id=tenant_id) + + if not knowledgebase: + logger.info(f"Creating FAQ knowledgebase {kb_name} for tenant {tenant_id}") + faq_config_service = FAQConfigService(session) + faq_config = await faq_config_service.get_faq_config_by_chatbot_id( + chatbot_id=chatbot.id, tenant_id=tenant_id + ) + + if faq_config and faq_config.embedding_model: + embedding_model = faq_config.embedding_model + else: + embedding_model = default_embedding_config.model_id + + default_similarity_threshold = faq_config.similarity_threshold if faq_config else DEFAULT_FAQ_SIMILARITY_THRESHOLD + + retrieval_config = RetrievalConfig( + 
retrieval_mode=VectorIndexRetrievalType.vector, + top_k=1, + enable_rerank=False, + rerank_top_k=None, + vector_weight=1.0, + similarity_threshold=default_similarity_threshold, + ) + + kb_create = KnowledgebaseCreate( + name=kb_name, + description="FAQ知识库", + embedding_model=embedding_model, + retrieval_config=retrieval_config, + ) + knowledgebase = await knowledgebase_service.create_knowledgebase(kb_data=kb_create, tenant_id=tenant_id) + try: + await session.commit() + await session.refresh(knowledgebase) + await knowledgebase_service.write_cache_after_commit(knowledgebase, tenant_id) + logger.info(f"Created FAQ knowledgebase {knowledgebase.id} for tenant {tenant_id}") + except IntegrityError: + await session.rollback() + if knowledgebase: + await knowledgebase_service.delete_cache_on_rollback(knowledgebase.id, tenant_id, kb_create.name) + knowledgebase = await knowledgebase_service.get_knowledgebase_by_name(kb_name, tenant_id=tenant_id) + if not knowledgebase: + raise ApiException(code=500, message="无法创建或获取FAQ知识库: 并发创建冲突") + logger.info(f"Retrieved existing FAQ knowledgebase {knowledgebase.id} for tenant {tenant_id}") + except Exception: + await session.rollback() + if knowledgebase: + await knowledgebase_service.delete_cache_on_rollback(knowledgebase.id, tenant_id, kb_create.name) + raise + else: + logger.info(f"Found existing FAQ knowledgebase {knowledgebase.id} for tenant {tenant_id}") + + + + parsed_chunk_config = knowledgebase.chunk_config + parsed_chunk_config = ChunkConfig.model_validate(parsed_chunk_config) + + # Parse and validate table_config if provided + parsed_table_config = None + if table_config: + try: + table_config_dict = json.loads(table_config) + if not isinstance(table_config_dict, dict): + raise ValueError("table_config must be a JSON object") + parsed_table_config = TableParserConfig.model_validate(table_config_dict) + if parsed_chunk_config.table_config: + table_config_dict_merged = parsed_chunk_config.table_config.model_dump() + 
table_config_dict_merged.update(table_config_dict) + parsed_chunk_config.table_config = TableParserConfig.model_validate(table_config_dict_merged) + else: + parsed_chunk_config.table_config = parsed_table_config + except json.JSONDecodeError as e: + raise ApiException(code=400, message=f"table_config 格式错误: {e}") + except Exception as e: + raise ApiException(code=400, message=f"table_config 验证失败: {e}") + + + + # Process each file + response_data = [] + total_chunks = 0 + + for file in files: + temp_file_path = None + try: + file_content = await file.read() + file_content_io = BytesIO(file_content) + file_extension = "." + (file.filename.split(".")[-1] if "." in file.filename else "") + + # Create temporary file to store file content + temp_file = tempfile.NamedTemporaryFile( + mode='wb', + suffix=file_extension, + delete=False, + prefix=f"faq_upload_{uuid.uuid4().hex}_" + ) + temp_file_path = temp_file.name + temp_file.write(file_content) + temp_file.close() + + # Create FileItem for parsing + file_id = uuid.uuid4().hex + file_md5 = hashlib.md5(file_content).hexdigest() + + file_item = FileItem( + id=file_id, + file_path=temp_file_path, + file=file_content_io, + kb_id=knowledgebase.id, + file_extension=file_extension, + file_name=file.filename or f"faq_file_{file_id}", + file_md5=file_md5, + file_size=len(file_content), + tenant_id=tenant_id, + ) + + logger.info(f"Parsing FAQ file {file_item.file_name} from temporary path {temp_file_path}...") + file_entity = to_file_entity(file_item=file_item) + # Convert ChunkConfig object to dict for file_entity + file_entity.chunk_config = parsed_chunk_config.model_dump() + file_parser = await kb_file_client.create_file_parser(knowledgebase, file_entity) + documents, nodes = file_parser.parse(file_item, is_attachment=False) + + if not nodes: + logger.warning(f"No nodes parsed from file {file_item.file_name}.") + response_data.append({"file_name": file_item.file_name, "items_count": 0}) + continue + + # Sanitize text + for node 
in nodes: + node.text = sanitize_text(node.text) + + logger.info(f"Parsed {len(nodes)} documents from FAQ file {file_item.file_name}.") + + # Save FAQ items to database from node metadata + faq_item_service = FAQItemService(session) + saved_faq_count = 0 + + # Prepare FAQ items data from all nodes + faq_items_to_create = [] + for node in nodes: + question = node.metadata.get("question", "").strip() if node.metadata else "" + answer = node.metadata.get("answer", "").strip() if node.metadata else "" + + # Skip if both question and answer are empty + if not question and not answer: + logger.warning(f"Skipping node with empty question and answer from file {file_item.file_name}") + continue + + faq_item_data = FAQItemCreate( + question=question, + answer=answer, + chatbot_id=chatbot.app_id, + file_id=file_id, + active=True, + ) + faq_items_to_create.append(faq_item_data) + + if faq_items_to_create: + created_faq_items = [] + for faq_item_data in tqdm(faq_items_to_create, desc=f"Creating FAQ Items for file {file_item.file_name}"): + try: + faq_item = await faq_item_service.create_faq_item( + chatbot_id=chatbot.app_id, + faq_item_data=faq_item_data, + tenant_id=tenant_id, + ) + created_faq_items.append(faq_item) + except Exception as create_error: + logger.error(f"Failed to create FAQ item: {create_error}") + continue + + # Save to knowledgebase in parallel batches + if created_faq_items: + save_tasks = [ + faq_item_service.save_faq_to_knowledgebase(faq_item, tenant_id, rag_service) + for faq_item in created_faq_items + ] + + if save_tasks: + # Process save tasks in batches to avoid overwhelming the system + save_batch_size = 50 + for j in tqdm(range(0, len(save_tasks), save_batch_size), desc=f"Saving FAQ Items to KB for file {file_item.file_name}"): + save_batch = save_tasks[j:j + save_batch_size] + await asyncio.gather(*save_batch, return_exceptions=True) + + saved_faq_count = len(created_faq_items) + + # Commit all FAQ items + await session.commit() + + # Refresh all 
created items (skip if refresh fails) + for faq_item in created_faq_items: + try: + await session.refresh(faq_item) + except Exception as refresh_error: + logger.debug(f"Could not refresh FAQ item {faq_item.id} (may not be persistent): {refresh_error}") + + logger.info(f"Saved {saved_faq_count}/{len(nodes)} FAQ items to database from file {file_item.file_name}.") + + # Add successful result to response_data + response_data.append({ + "file_name": file_item.file_name, + "items_count": saved_faq_count + }) + total_chunks += saved_faq_count + except Exception as file_error: + logger.error(f"Failed to process FAQ file {file.filename}: {traceback.format_exc()}") + response_data.append({ + "file_name": file.filename, + "items_count": 0, + "error": str(file_error) + }) + # Continue processing other files even if one fails + finally: + # Clean up temporary file + if temp_file_path and os.path.exists(temp_file_path): + try: + os.unlink(temp_file_path) + logger.debug(f"Deleted temporary file: {temp_file_path}") + except Exception as cleanup_error: + logger.warning(f"Failed to delete temporary file {temp_file_path}: {cleanup_error}") + + await session.commit() + + logger.info(f"Uploaded {len(files)} FAQ files successfully, total chunks: {total_chunks}.") + return success_response( + data=response_data, + message=f"成功上传并解析 {len(files)} 个文件,共提取 {total_chunks} 个片段。" + ) + except Exception as e: + logger.error(f"Failed to process FAQ file: {traceback.format_exc()}") + raise ApiException(code=400, message=f"文件处理失败: {e}") + @app_router.post("", response_model=ResponseModel[ChatBotEntity]) async def create_chatbot( @@ -29,6 +507,8 @@ async def create_chatbot( ): try: chatbot = await chatapp_service.create_chatapp(app_data=chatbot_create, tenant_id=tenant_id) + await session.commit() + await session.refresh(chatbot) return success_response(data=chatbot, message="创建应用成功。") except ValueError as e: logger.error(f"Failed to create chatapp: {str(e)}") @@ -89,9 +569,11 @@ async def 
delete_chatbot( tenant_id: str = Depends(get_tenant_id), session: AsyncSession = Depends(get_db_session), chatapp_service: ChatappService = Depends(get_chatapp_service), + rag_service: RagService = Depends(get_rag_service), ): try: - await chatapp_service.delete_chatapp(id=id, tenant_id=tenant_id) + await chatapp_service.delete_chatapp(id=id, tenant_id=tenant_id, rag_service=rag_service) + await session.commit() return success_response(message=f"应用'{id}'删除成功。") except ValueError as e: logger.error(f"Failed to delete chatapp: {str(e)}") diff --git a/backend/api/v1/config_apis/knowledgebase.py b/backend/api/v1/config_apis/knowledgebase.py index 2c879bfa4..3bb7f6ee5 100644 --- a/backend/api/v1/config_apis/knowledgebase.py +++ b/backend/api/v1/config_apis/knowledgebase.py @@ -4,7 +4,8 @@ import traceback from typing import List, Optional from common.knowledgebase.types import FileStatus -from fastapi import APIRouter, Depends, File, Query, UploadFile, Form +from fastapi import APIRouter, Depends, File, Query, UploadFile, Form, Body +import json from pydantic import BaseModel, Field from sqlmodel.ext.asyncio.session import AsyncSession from rag.file_item_utils import to_file_entity @@ -13,6 +14,7 @@ from db.models.knowledgebase.knowledgebase import ( KbEntity, KnowledgebaseCreate, + ChunkConfig, ) from db.db_context import get_db_session from pairag.file.store.file_store_helper import file_store @@ -208,6 +210,13 @@ async def start_parse_task( if not kb_entity: raise ValueError(f"知识库 {kb_id} 不存在。") + # Validate chunk_config if provided + if parse_request.chunk_config: + try: + ChunkConfig.model_validate(parse_request.chunk_config) + except Exception as e: + raise ApiException(code=400, message=f"chunk_config 格式错误: {e}") + file_items = await upload_file_names_async(kb_id=kb_id, parse_tasks=parse_request.files, tenant_id=tenant_id) file_version = int(time.time()) @@ -228,13 +237,23 @@ async def start_parse_task( else: file_entity = to_file_entity(file_item=file_item) + # 
Update chunk_config if provided + if parse_request.chunk_config: + file_entity.chunk_config = parse_request.chunk_config + file_entity.file_version = file_version background_worker.enqueue_file_tasks.delay(file_entity.id, file_entity.file_version, is_attachment=False, tenant_id=tenant_id) session.add(file_entity) file_entities.append(file_entity) + await session.commit() + for file_entity in file_entities: + await session.refresh(file_entity) + logger.info(f"Uploaded {len(file_entities)} files successfully.") return success_response(data=file_entities, message="启动解析任务成功") + except ApiException: + raise except ValueError as e: logger.error(f"启动解析任务失败。\nValueError:{traceback.format_exc()}") raise ApiException(code=400, message=str(e)) @@ -250,15 +269,39 @@ async def upload_files( auto_parse: bool = Query(default=True), files: Optional[List[UploadFile]] = File(...), file_sources: Optional[List[str]] = Form(None), + chunk_config: Optional[str] = Form(None, description="JSON string of chunk_config, shared by all files in this upload"), tenant_id: str = Depends(get_tenant_id), session: AsyncSession = Depends(get_db_session), rag_service: RagService = Depends(get_rag_service), file_service: FileService = Depends(get_file_service), + knowledgebase_service: KnowledgebaseService = Depends(get_knowledgebase_service), ): try: file_version = int(time.time()) if not files: raise ApiException(code=400, message="没有上传任何文件。") + + + knowledgebase = await knowledgebase_service.get_knowledgebase(kb_id=kb_id, tenant_id=tenant_id) + if not knowledgebase: + raise ApiException.not_found(kb_id, "知识库") + + kb_chunk_config = knowledgebase.chunk_config + + + parsed_chunk_config = None + if chunk_config: + try: + parsed_chunk_config = json.loads(chunk_config) + if not isinstance(parsed_chunk_config, dict): + raise ValueError("chunk_config must be a JSON object") + # Validate chunk_config + ChunkConfig.model_validate(parsed_chunk_config) + except json.JSONDecodeError as e: + raise 
ApiException(code=400, message=f"chunk_config 格式错误: {e}") + except Exception as e: + raise ApiException(code=400, message=f"chunk_config 验证失败: {e}") + file_items = await upload_form_files_async(kb_id=kb_id, files=files, tenant_id=tenant_id) file_names = [file_item.file_name for file_item in file_items] @@ -281,10 +324,15 @@ async def upload_files( if file_sources: assert len(file_sources) == len(new_file_entities), "文件来源列表长度与文件列表长度不一致" - for i,file_entity in enumerate(new_file_entities): + # Apply chunk_config to all files if provided (shared by all files in this upload) + for i, file_entity in enumerate(new_file_entities): if file_sources: file_entity.file_source = file_sources[i] + # All files share the same chunk_config if provided + if parsed_chunk_config: + file_entity.chunk_config = parsed_chunk_config + if auto_parse: import app.worker as background_worker background_worker.enqueue_file_tasks.delay(file_entity.id, file_entity.file_version, is_attachment=False, tenant_id=tenant_id) @@ -292,8 +340,22 @@ async def upload_files( session.add(file_entity) + await session.commit() + + + for file_entity in new_file_entities: + await session.refresh(file_entity) + + + response_entities = [] + for file_entity in new_file_entities: + file_dict = file_entity.model_dump() + if not file_entity.chunk_config: + file_dict["chunk_config"] = kb_chunk_config + response_entities.append(file_dict) + logger.info(f"Uploaded {len(new_file_entities)} files successfully.") - return success_response(data=new_file_entities, message="上传文件成功") + return success_response(data=response_entities, message="上传文件成功") except ValueError as e: logger.error(f"上传文件失败。\nValueError:{e}") raise ApiException(code=400, message=str(e)) @@ -327,21 +389,44 @@ async def get_kb_file( raise ApiException(code=400, message=f"查询文件失败: {e}.") +class ReprocessFileRequest(BaseModel): + chunk_config: Optional[dict] = Field(default=None, description="Optional chunk configuration for the file") + + 
class BatchOperationRequest(BaseModel):
    """Request body for batch file operations on a knowledgebase."""

    operation: str = Field(..., description="操作类型: 'delete' 或 'reprocess'")
    file_id_list: List[str] = Field(..., description="要操作的文件ID列表")
    # Only meaningful for the 'reprocess' operation; applied to every listed file.
    chunk_config: Optional[dict] = Field(default=None, description="Optional chunk configuration for reprocess operation, shared by all files")
try: - reprocessed_count = await _batch_reprocess_files(kb_id=kb_id, file_entities=file_entities, session=session, tenant_id=tenant_id) + chunk_config = None + if request.chunk_config: + try: + ChunkConfig.model_validate(request.chunk_config) + chunk_config = request.chunk_config + except Exception as e: + raise ApiException(code=400, message=f"chunk_config 格式错误: {e}") + + reprocessed_count = await _batch_reprocess_files( + kb_id=kb_id, + file_entities=file_entities, + session=session, + tenant_id=tenant_id, + chunk_config=chunk_config + ) return success_response(data=reprocessed_count, message=f"成功将 {reprocessed_count} 个文件加入重新处理队列。") + except ApiException: + raise except ValueError as e: logger.error(f"重新处理文件失败。\nValueError:{e}") raise ApiException(code=400, message=str(e)) @@ -437,9 +538,11 @@ async def _batch_reprocess_files( file_entities: List[KbFileEntity], session: AsyncSession, tenant_id: str, + chunk_config: Optional[dict] = None, ) -> ResponseModel[dict]: """ 批量重新处理文件的内部实现 + 如果提供了 chunk_config,会在重新解析之前更新所有文件的 chunk_config """ import app.worker as background_worker @@ -450,6 +553,12 @@ async def _batch_reprocess_files( file_entity.status = FileStatus.pending file_entity.file_version = file_version file_entity.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + + # 如果提供了 chunk_config,更新文件的 chunk_config + if chunk_config: + file_entity.chunk_config = chunk_config + logger.info(f"Updated chunk_config for file {file_entity.id} before reprocessing.") + session.add(file_entity) reprocessed_count += 1 @@ -474,7 +583,6 @@ class FileSourceParam(BaseModel): file_source: str = Field(default=None) - @knowledgebase_router.post("/{kb_id}/files/{file_id}/source", response_model=ResponseModel[KbFileEntity]) async def set_file_source( kb_id: str, diff --git a/backend/api/v1/faq_retrieval.py b/backend/api/v1/faq_retrieval.py new file mode 100644 index 000000000..c00c089e3 --- /dev/null +++ b/backend/api/v1/faq_retrieval.py @@ -0,0 +1,124 @@ +from fastapi import 
class FAQRetrievalRequest(BaseModel):
    """Request payload for retrieving FAQ entries bound to a chat application."""

    chatapp_id: str  # chatbot.app_id identifying the chat application
    query: str  # natural-language query to match against FAQ entries
    user_id: Optional[str] = None
    retrieval_setting: Optional[RetrievalSetting] = None
    # Accepted for API symmetry with /v1/retrieval; FAQ retrieval currently
    # does not forward it (see NOTE in faq_retrieval below).
    metadata_condition: Optional[MetadataFilteringCondition] = None


def _faq_config_get(faq_config, key, default=None):
    """Read one field from a chatbot's FAQ config.

    ``ChatBotEntity.faq_config`` is persisted as a JSON dict, but a caller may
    also hand us an already-validated config object, so support both access
    styles. ``None`` values fall back to ``default``.
    """
    if isinstance(faq_config, dict):
        value = faq_config.get(key, default)
    else:
        value = getattr(faq_config, key, default)
    return default if value is None else value


@faq_retrieval_router.post(
    "", response_model=ResponseModel[NewRetrievalResponse]
)
async def faq_retrieval(
    retrieval_request: FAQRetrievalRequest,
    session: AsyncSession = Depends(get_db_session),
    tenant_id: str = Depends(get_tenant_id),
    rag_service: RagService = Depends(get_rag_service),
    chatapp_service: ChatappService = Depends(get_chatapp_service),
    faq_config_service: FAQConfigService = Depends(get_faq_config_service),
):
    """Retrieve FAQ entries for a chat application.

    Resolves the chatbot by ``chatapp_id``, locates its FAQ knowledgebase via
    the chatbot's FAQ config, then runs a (default: pure vector, top-1) query
    and returns the matching records.

    Raises:
        ApiException: 404 when the app, its FAQ config or the FAQ
            knowledgebase is missing; 400 on invalid input; 500 otherwise.
    """
    logger.info(
        f"FAQ Retrieval request: chatapp_id={retrieval_request.chatapp_id}, "
        f"query={retrieval_request.query}, tenant_id={tenant_id}"
    )
    try:
        chatbot = await chatapp_service.get_chatapp_by_app_id(
            app_id=retrieval_request.chatapp_id,
            tenant_id=tenant_id,
        )
        if not chatbot:
            raise ApiException(code=404, message=f"应用 '{retrieval_request.chatapp_id}' 不存在。")

        # Get FAQ config to get kb_id and similarity_threshold
        faq_config = chatbot.faq_config
        if not faq_config:
            raise ApiException(code=404, message=f"FAQ配置 '{retrieval_request.chatapp_id}' 不存在。")

        # BUG FIX: faq_config is stored as a JSON dict on ChatBotEntity, so
        # plain attribute access (faq_config.kb_id) raised AttributeError and
        # surfaced as a 500. Use dict/attribute-agnostic access instead.
        kb_id = _faq_config_get(faq_config, "kb_id")
        knowledgebase_service = KnowledgebaseService(session)
        kb = await knowledgebase_service.get_knowledgebase(kb_id, tenant_id=tenant_id)
        if not kb:
            raise ApiException(code=404, message=f"FAQ知识库 '{kb_id}' 不存在。")

        default_similarity_threshold = _faq_config_get(
            faq_config, "similarity_threshold", DEFAULT_FAQ_SIMILARITY_THRESHOLD
        )

        # Set default retrieval_setting if not provided, or merge with defaults
        user_setting = retrieval_request.retrieval_setting
        if user_setting is None:
            retrieval_setting = RetrievalSetting(
                retrieval_mode=VectorIndexRetrievalType.vector,
                top_k=1,
                enable_rerank=False,
                rerank_top_k=None,
                vector_weight=1.0,
                similarity_threshold=default_similarity_threshold,
            )
        else:
            # Merge user-provided settings with FAQ defaults
            # (pure vector search, top_k=1, no rerank).
            retrieval_setting = RetrievalSetting(
                retrieval_mode=user_setting.retrieval_mode or VectorIndexRetrievalType.vector,
                top_k=user_setting.top_k if user_setting.top_k is not None else 1,
                enable_rerank=user_setting.enable_rerank if user_setting.enable_rerank is not None else False,
                rerank_top_k=user_setting.rerank_top_k,
                rerank_model=user_setting.rerank_model,
                rerank_provider_name=user_setting.rerank_provider_name,
                vector_weight=user_setting.vector_weight if user_setting.vector_weight is not None else 1.0,
                similarity_threshold=(
                    user_setting.similarity_threshold
                    if user_setting.similarity_threshold is not None
                    else default_similarity_threshold
                ),
                score_threshold=user_setting.score_threshold,
            )

        # NOTE(review): metadata_condition from the request is intentionally
        # not forwarded here — confirm whether FAQ retrieval should honour it.
        search_results: List[SearchResult] = await rag_service.aquery(
            query=retrieval_request.query,
            user_id=retrieval_request.user_id,
            kb_id=kb.id,
            kb_id_list=None,
            retrieval_setting=retrieval_setting,
            metadata_condition=None,
            tenant_id=tenant_id,
        )

        logger.info(
            f"Retrieved {len(search_results)} FAQ results for query "
            f"'{retrieval_request.query}' from knowledgebase {kb.id}."
        )

        records = [
            DocRecord(
                content=node.content,
                score=node.score,
                title=node.title,
                metadata=node.metadata,
            )
            for node in search_results
        ]

        # 使用统一的响应格式
        retrieval_response = NewRetrievalResponse(records=records)
        return success_response(data=retrieval_response, message="FAQ检索成功")
    except ApiException:
        raise
    except ValueError as e:
        logger.error(f"Failed to retrieve FAQ: {traceback.format_exc()}")
        raise ApiException(code=400, message=f"FAQ检索失败: {e}")
    except Exception as e:
        logger.error(f"Failed to retrieve FAQ: {traceback.format_exc()}")
        raise ApiException(code=500, message=f"FAQ检索失败: {e}")
prefix="/v1/faq-retrieval") app.include_router(retrieval_tool_router, prefix="/v1/tools/retrieval") app.include_router(embedding_router, prefix="/v1/embeddings") diff --git a/backend/common/chat/models.py b/backend/common/chat/models.py index e6d4ee918..e0920f5d7 100644 --- a/backend/common/chat/models.py +++ b/backend/common/chat/models.py @@ -94,6 +94,7 @@ class ChatAgentRequest(BaseModel): mcp_ids: Optional[List[str]] = [] kb_ids: Optional[List[str]] = [] + faq_config: Optional[dict] = None enable_search: Optional[bool] = False enable_agent: Optional[bool] = False enable_chatdb: Optional[bool] = False diff --git a/backend/common/knowledgebase/constants.py b/backend/common/knowledgebase/constants.py index 90b023bf8..be0482dca 100644 --- a/backend/common/knowledgebase/constants.py +++ b/backend/common/knowledgebase/constants.py @@ -1,7 +1,7 @@ DEFAULT_CHUNK_SIZE = 1024 DEFAULT_CHUNK_OVERLAP = 50 DEFAULT_PARSER_TYPE = "structure" -DEFAULT_SENTENCE_SEPARATOR = "\n\n" +DEFAULT_PARAGRAPH_SEPARATOR = "\n\n" DEFAULT_EMBEDDING_MODEL = "BAAI/bge-m3" DEFAULT_RERANK_MODEL = "BAAI/bge-reranker-v2-m3" @@ -14,10 +14,13 @@ DEFAULT_VECTOR_WEIGHT = 0.5 +DEFAULT_FAQ_SIMILARITY_THRESHOLD = 0.8 + DEFAULT_KNOWLEDGEBASE_PATH = "localdata/knowledgebases" ATTACHMENT_KNOWLEDGEBASE_NAME = "default_attachments" +FAQ_KNOWLEDGEBASE_NAME = "default_faqs" DEFAULT_VECTOR_ID = "default_vectordb" diff --git a/backend/db/models/__init__.py b/backend/db/models/__init__.py index 67b5578d7..eca10be6e 100644 --- a/backend/db/models/__init__.py +++ b/backend/db/models/__init__.py @@ -14,6 +14,7 @@ from db.models.knowledgebase.metadata import KbMetadataEntity, FileMetadataEntity from db.models.prompt import PromptModelEntity from db.models.chatbot import ChatBotEntity +from db.models.faq_item import FAQItemEntity from db.models.guardrail import GuardrailConfigEntity from db.models.code_sandbox import CodeSandboxConfigEntity from db.models.evaluation.dataset import DatasetEntity, DatasetSampleEntity diff 
class FAQConfigCreate(SQLModel):
    """FAQ feature configuration attached to a chatbot.

    Persisted as JSON on ``ChatBotEntity.faq_config``; mirrors the fields the
    FAQ retrieval path reads (kb_id, similarity_threshold, ...).
    """

    active: bool = Field(
        default=True
    )
    # FAQ configuration fields
    kb_id: Optional[str] = Field(default=None, description="FAQ知识库ID")
    similarity_threshold: Optional[float] = Field(default=DEFAULT_FAQ_SIMILARITY_THRESHOLD, description="相似度阈值,范围0.8-1.0")
    embedding_model: Optional[str] = Field(default="BAAI/bge-m3", description="Embedding模型ID")
    enable_question_in_retrieval: Optional[bool] = Field(default=True, description="问题是否参与检索")
    enable_question_in_response: Optional[bool] = Field(default=True, description="问题是否参与回答")
    enable_answer_in_retrieval: Optional[bool] = Field(default=False, description="答案是否参与检索")
    enable_answer_in_response: Optional[bool] = Field(default=True, description="答案是否参与回答")
    return_direct: Optional[bool] = Field(default=False, description="是否直接返回工具结果,不经过LLM加工")


class FAQItemCreate(SQLModel):
    """Mutable fields of a single FAQ question/answer pair."""

    # NOTE(review): annotated `str` but defaulted to None — presumably callers
    # always supply question/answer/chatbot_id; confirm or relax the annotation.
    question: str = Field(default=None, sa_column=Column(Text))
    answer: str = Field(default=None, sa_column=Column(Text))
    chatbot_id: str = Field(default=None)
    file_id: Optional[str] = Field(default=None)
    active: bool = Field(default=True)


class FAQItemEntity(FAQItemCreate, table=True):
    """Database row for one FAQ item (table ``pai_chatbot_faq``)."""

    __tablename__ = "pai_chatbot_faq"

    id: str = Field(primary_key=True, default_factory=lambda: str(uuid.uuid4().hex))
    tenant_id: Optional[str] = Field(default=DEFAULT_TENANT_ID, index=True)
    # Cascade-deleted together with the owning chatbot.
    chatbot_id: str = Field(default=None, foreign_key="pai_chatbot_model.app_id", ondelete="CASCADE", index=True)

    # Naive UTC timestamps (tzinfo stripped before storage).
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        sa_column=Column(DateTime),
    )
    updated_at: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        sa_column=Column(DateTime),
    )
class TableParserConfig(SQLModel):
    """Configuration for table parser (CSV/Excel). Only used when parser_type == 'table'."""

    concat_rows: bool = Field(default=False, description="Whether to concatenate all rows into one document")
    row_joiner: str = Field(default="\n", description="Separator to use for joining each row")
    header_index_max: Optional[int] = Field(default=0, description="Maximum row index to use as header")
    format_sheet_data_to_json: bool = Field(default=False, description="Whether to format sheet data as JSON")
    sheet_column_filters: Optional[List[str]] = Field(default=None, description="List of column names to filter")
    question_column_index: Optional[int] = Field(default=0, description="Index of question column")
    answer_column_index: Optional[int] = Field(default=1, description="Index of answer column")


class ChunkConfig(SQLModel):
    """Document chunking settings, stored per knowledgebase and optionally
    overridden per file via ``KbFileEntity.chunk_config``."""

    chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE)
    chunk_overlap: int = Field(default=DEFAULT_CHUNK_OVERLAP)
    parser_type: str = Field(default=DEFAULT_PARSER_TYPE)
    separator: str = Field(default=DEFAULT_PARAGRAPH_SEPARATOR)
    # Optional multimodal model used to caption images found in documents.
    image_caption_model: Optional[str] = Field(default=None)
    image_caption_provider_name: str = Field(default="openai_like")
    table_config: Optional[TableParserConfig] = Field(default=None, description="Table parser configuration (only used when parser_type == 'table')")
Optional[ImageCaptionTool] = None): - chunk_config = ChunkConfig.model_validate(knowledgebase.chunk_config) + async def create_file_parser(self, knowledgebase: KbEntity, file_entity: Optional[KbFileEntity] = None, image_caption_tool: Optional[ImageCaptionTool] = None): + # Prioritize file's chunk_config if available, otherwise use knowledgebase's chunk_config + if file_entity and file_entity.chunk_config: + chunk_config = ChunkConfig.model_validate(file_entity.chunk_config) + else: + chunk_config = ChunkConfig.model_validate(knowledgebase.chunk_config) image_caption_tool = None if chunk_config.image_caption_model: @@ -48,7 +52,7 @@ async def create_file_parser(self, knowledgebase: KbEntity, image_caption_tool: image_caption_tool = ImageCaptionTool(multimodal_llm=multimodal_llm) if not chunk_config.separator: - chunk_config.separator = DEFAULT_SENTENCE_SEPARATOR + chunk_config.separator = DEFAULT_PARAGRAPH_SEPARATOR file_parser = FileParser( file_store=file_store, @@ -151,7 +155,7 @@ async def process_file_async( return # parsing file logger.info(f"Parsing file {file_item.file_name}.") - file_parser = await self.create_file_parser(knowledgebase) + file_parser = await self.create_file_parser(knowledgebase, file_entity=file_entity) documents, nodes = file_parser.parse(file_item, is_attachment=is_attachment) await update_file_content_async(file_id=file_item.id, is_attachment=is_attachment, documents=documents, tenant_id=tenant_id) for node in nodes: diff --git a/backend/rag/parse_utils.py b/backend/rag/parse_utils.py index 5bed6eb18..b0ca4e7f1 100644 --- a/backend/rag/parse_utils.py +++ b/backend/rag/parse_utils.py @@ -28,8 +28,11 @@ def sanitize_text(text: str) -> str: def get_node_texts_for_embedding(nodes) -> list[str]: texts = [] for node in nodes: - base_text = f"filename: {node.metadata['file_name']}" + base_text = "" + file_name = node.metadata.get('file_name', '').strip() chapter_name = node.metadata.get('chapter_name', '').strip() + if file_name: + base_text 
+= f"file_name: {file_name}" if chapter_name: base_text += f"\n\nchapter_name: {chapter_name}" diff --git a/backend/service/agent/agent_service.py b/backend/service/agent/agent_service.py index fb1b9b4a1..3c7ba0b9f 100644 --- a/backend/service/agent/agent_service.py +++ b/backend/service/agent/agent_service.py @@ -1,6 +1,7 @@ from common.chat.models import ChatAgentRequest from service.factory.model_factory import create_llm from tools.knowledgebase.knowledgebase_tool import aget_knowledgebase_tool +from tools.knowledgebase.faq_tool import aget_faq_tool from service.factory.tools import create_search_tools, create_chatdb_tools, create_codesandbox_tools from service.factory.mcp_factory import create_mcp_tools_async from tools.attachments.file_searcher import aget_file_searcher @@ -42,10 +43,12 @@ def __init__( chatdb_service_getter: Callable[[], Awaitable], rag_service_getter: Callable[[], Awaitable], file_service_getter: Callable[[], Awaitable], + faq_config_service_getter: Callable[[], Awaitable], ): self.session = session self._get_llm_service = llm_service_getter self._get_chatapp_service = chatapp_service_getter + self._get_faq_config_service = faq_config_service_getter self._get_websearch_service = websearch_service_getter self._get_codesandbox_service = codesandbox_service_getter self._get_chatdb_service = chatdb_service_getter @@ -59,6 +62,7 @@ async def create_agent(self, chat_request: ChatAgentRequest, tenant_id: str) -> llm_model = await llm_service.get_llm_by_model_id(chat_request.model, tenant_id=tenant_id) prompt_set = PlanAgentPromptSet() + chatapp_id = None if llm_model: llm = create_llm(llm_model) @@ -68,8 +72,10 @@ async def create_agent(self, chat_request: ChatAgentRequest, tenant_id: str) -> if not chatapp: raise ValueError(f"Model `{chat_request.model}` not found.") + chatapp_id = chatapp.app_id chat_request.model = chatapp.model_id chat_request.mcp_ids = chatapp.mcp_ids + chat_request.faq_config = chatapp.faq_config chat_request.kb_ids = 
chatapp.kb_ids chat_request.enable_search = chatapp.enable_search chat_request.enable_chatdb = chatapp.enable_chatdb @@ -96,7 +102,9 @@ async def create_agent(self, chat_request: ChatAgentRequest, tenant_id: str) -> enable_chatdb=chat_request.enable_chatdb, mcp_ids=chat_request.mcp_ids, kb_ids=chat_request.kb_ids, + faq_config=chat_request.faq_config, user_id=chat_request.user_id, + chatapp_id=chatapp_id, metadata_condition=chat_request.metadata_condition, tenant_id=tenant_id, ) @@ -125,15 +133,24 @@ async def aget_tools( mcp_ids: List[str] = [], kb_ids: List[str] = [], tenant_id: str = None, + chatapp_id: Optional[str] = None, + faq_config: Optional[dict] = None, ) -> tuple[List[FunctionTool], Callable | None]: tools = [] # 知识库工具 rag_service = await self._get_rag_service() + chatapp_service = await self._get_chatapp_service() + faq_config_service = await self._get_faq_config_service() for kb_id in kb_ids: tools.append(await aget_knowledgebase_tool(kb_id=kb_id, user_id=user_id, rag_service=rag_service, tenant_id=tenant_id, metadata_condition=metadata_condition)) logger.info(f"Resolved {len(kb_ids)} knowledgebase tools.") + # FAQ工具 + if faq_config and faq_config.get("active"): + tools.append(await aget_faq_tool(chatapp_id=chatapp_id, user_id=user_id, rag_service=rag_service, chatapp_service=chatapp_service, faq_config_service=faq_config_service, tenant_id=tenant_id)) + logger.info("Resolved FAQ tool.") + # 搜索工具 if enable_search: # Add search web tool diff --git a/backend/service/factory/tools.py b/backend/service/factory/tools.py index 739f124bd..0a0778d1c 100644 --- a/backend/service/factory/tools.py +++ b/backend/service/factory/tools.py @@ -53,12 +53,14 @@ async def aget_search_result(query: str) -> str: async_fn=aget_search_result, name="tavily-websearch", description="从 Tavily 搜索引擎中搜索给定查询的最新内容。", + return_direct=False, ) else: search_tool = FunctionTool.from_defaults( async_fn=aget_search_result, name="aliyun-websearch", description="从阿里云搜索引擎中搜索给定查询的最新内容。", + 
async def get_faq_config_service(
    session: AsyncSession = Depends(get_db_session),
) -> FAQConfigService:
    """FastAPI dependency provider.

    Builds a FAQConfigService bound to the request-scoped database session.
    """
    return FAQConfigService(session)


async def get_faq_item_service(
    session: AsyncSession = Depends(get_db_session),
) -> FAQItemService:
    """FastAPI dependency provider.

    Builds a FAQItemService bound to the request-scoped database session.
    """
    return FAQItemService(session)
await cache_manager.get_cache().get(cache_key) - if kb_data: - logger.info(f"Get knowledgebase entity from cache: {kb_id}") - kb_entity = KbEntity.model_validate(kb_data) - return kb_entity + try: + kb_data = await cache_manager.get_cache().get(cache_key) + if kb_data: + logger.info(f"Get knowledgebase entity from cache: {kb_id}") + kb_entity = KbEntity.model_validate(kb_data) + return kb_entity + except Exception as e: + logger.warning(f"Cache get operation failed for {cache_key}: {e}") result = await self.session.exec(select(KbEntity).where(KbEntity.id == kb_id, KbEntity.tenant_id == tenant_id)) kb_entity = result.first() if kb_entity: - await cache_manager.get_cache().set(cache_key, kb_entity.model_dump(mode="json")) + try: + await cache_manager.get_cache().set(cache_key, kb_entity.model_dump(mode="json")) + except Exception as e: + logger.warning(f"Cache set operation failed for {cache_key}: {e}") return kb_entity async def get_knowledgebase_by_name(self, name: str, tenant_id: str) -> Optional[KbEntity]: @@ -67,17 +74,23 @@ async def get_knowledgebase_by_name(self, name: str, tenant_id: str) -> Optional KbEntity if found, None otherwise """ cache_key = kb_name_key(tenant_id, name) - kb_data = await cache_manager.get_cache().get(cache_key) - if kb_data: - logger.info(f"Get knowledgebase entity from cache: {name}") - kb_entity = KbEntity.model_validate(kb_data) - return kb_entity + try: + kb_data = await cache_manager.get_cache().get(cache_key) + if kb_data: + logger.info(f"Get knowledgebase entity from cache: {name}") + kb_entity = KbEntity.model_validate(kb_data) + return kb_entity + except Exception as e: + logger.warning(f"Cache get operation failed for {cache_key}: {e}, falling back to database") statement = select(KbEntity).where(KbEntity.name == name, KbEntity.tenant_id == tenant_id) result = await self.session.exec(statement) kb_entity = result.first() if kb_entity: - await cache_manager.get_cache().set(cache_key, kb_entity.model_dump(mode="json")) + try: 
+ await cache_manager.get_cache().set(cache_key, kb_entity.model_dump(mode="json")) + except Exception as e: + logger.warning(f"Cache set operation failed for {cache_key}: {e}") return kb_entity async def get_knowledgebases_by_ids(self, tenant_id: str, kb_ids: List[str]) -> List[KbEntity]: @@ -111,11 +124,14 @@ async def list_knowledgebases( Returns: PagedResult containing list of KbEntity with file_count and pagination metadata """ - # Build base query condition + conditions = [KbEntity.tenant_id == tenant_id] + if exclude_default_attachments: - base_condition = and_(KbEntity.name != "default_attachments", KbEntity.tenant_id == tenant_id) - else: - base_condition = KbEntity.tenant_id == tenant_id + conditions.append(KbEntity.name != "default_attachments") + + conditions.append(~KbEntity.name.like(f"%_{FAQ_KNOWLEDGEBASE_NAME}")) + + base_condition = and_(*conditions) # Add search condition if provided if query: query_lower = query.lower() @@ -254,7 +270,10 @@ async def update_knowledgebase( ValueError: If Knowledgebase entity not found """ cache_key = kb_key(tenant_id, kb_id) - await cache_manager.get_cache().delete(cache_key) + try: + await cache_manager.get_cache().delete(cache_key) + except Exception as e: + logger.warning(f"Cache delete operation failed for {cache_key}: {e}") result = await self.session.exec(select(KbEntity).where(KbEntity.id == kb_id, KbEntity.tenant_id == tenant_id)) knowledgebase = result.first() @@ -262,7 +281,10 @@ async def update_knowledgebase( raise ValueError(f"知识库 '{kb_id}' 不存在。") cache_name_key = kb_name_key(tenant_id, knowledgebase.name) - await cache_manager.get_cache().delete(cache_name_key) + try: + await cache_manager.get_cache().delete(cache_name_key) + except Exception as e: + logger.warning(f"Cache delete operation failed for {cache_name_key}: {e}") try: @@ -329,8 +351,14 @@ async def delete_knowledgebase(self, kb_id: str, tenant_id: str) -> None: # Delete both ID-based and name-based cache entries cache_key = 
kb_key(tenant_id, kb_id) cache_name_key = kb_name_key(tenant_id, knowledgebase.name) - await cache_manager.get_cache().delete(cache_key) - await cache_manager.get_cache().delete(cache_name_key) + try: + await cache_manager.get_cache().delete(cache_key) + except Exception as e: + logger.warning(f"Cache delete operation failed for {cache_key}: {e}") + try: + await cache_manager.get_cache().delete(cache_name_key) + except Exception as e: + logger.warning(f"Cache delete operation failed for {cache_name_key}: {e}") # Delete knowledgebase entity only await self.session.delete(knowledgebase) @@ -362,8 +390,14 @@ async def write_cache_after_commit(self, kb_entity: KbEntity, tenant_id: str) -> """ cache_key = kb_key(tenant_id, kb_entity.id) cache_name_key = kb_name_key(tenant_id, kb_entity.name) - await cache_manager.get_cache().set(cache_key, kb_entity.model_dump(mode="json")) - await cache_manager.get_cache().set(cache_name_key, kb_entity.model_dump(mode="json")) + try: + await cache_manager.get_cache().set(cache_key, kb_entity.model_dump(mode="json")) + except Exception as e: + logger.warning(f"Cache set operation failed for {cache_key}: {e}") + try: + await cache_manager.get_cache().set(cache_name_key, kb_entity.model_dump(mode="json")) + except Exception as e: + logger.warning(f"Cache set operation failed for {cache_name_key}: {e}") logger.info(f"Written cache for knowledgebase {kb_entity.id} (name: {kb_entity.name}) after commit") async def delete_cache_on_rollback(self, kb_id: str, tenant_id: str, kb_name: Optional[str] = None) -> None: @@ -377,8 +411,14 @@ async def delete_cache_on_rollback(self, kb_id: str, tenant_id: str, kb_name: Op kb_name: Optional knowledgebase name (if known) """ cache_key = kb_key(tenant_id, kb_id) - await cache_manager.get_cache().delete(cache_key) + try: + await cache_manager.get_cache().delete(cache_key) + except Exception as e: + logger.warning(f"Cache delete operation failed for {cache_key}: {e}") if kb_name: cache_name_key = 
kb_name_key(tenant_id, kb_name) - await cache_manager.get_cache().delete(cache_name_key) + try: + await cache_manager.get_cache().delete(cache_name_key) + except Exception as e: + logger.warning(f"Cache delete operation failed for {cache_name_key}: {e}") logger.info(f"Deleted cache for knowledgebase {kb_id} on rollback") diff --git a/backend/service/knowledgebase/vector_table_mapping_service.py b/backend/service/knowledgebase/vector_table_mapping_service.py index 1decfe0de..0924f8ed2 100644 --- a/backend/service/knowledgebase/vector_table_mapping_service.py +++ b/backend/service/knowledgebase/vector_table_mapping_service.py @@ -39,26 +39,49 @@ async def get_vector_table_name(self, tenant_id: str, kb_id: str) -> str: The vector table name """ cache_key = vector_table_name_key(tenant_id, kb_id) - table_name = await cache_manager.get_cache().get(cache_key) - if table_name: - logger.debug(f"Found vector table name in cache for tenant {tenant_id} and kb {kb_id}: {table_name}") - return table_name + try: + table_name = await cache_manager.get_cache().get(cache_key) + if table_name: + logger.debug(f"Found vector table name in cache for tenant {tenant_id} and kb {kb_id}: {table_name}") + return table_name + except Exception as e: + logger.warning(f"Cache get operation failed for {cache_key}: {e}, falling back to database") # Try to find existing mapping mapping = await self._get_mapping(tenant_id, kb_id) if mapping: logger.debug(f"Found existing vector table mapping: {mapping.table_name}") - await cache_manager.get_cache().set(cache_key, mapping.table_name) + try: + await cache_manager.get_cache().set(cache_key, mapping.table_name) + except Exception as e: + logger.warning(f"Cache set operation failed for {cache_key}: {e}") return mapping.table_name # Generate new table name table_name = generate_vector_table_name(tenant_id, kb_id) logger.debug(f"Generated new vector table name: {table_name}") - # Store the mapping - await self._create_mapping(tenant_id, kb_id, table_name) 
- await cache_manager.get_cache().set(cache_key, table_name) + # Store the mapping (with retry logic for race conditions) + try: + mapping = await self._create_mapping(tenant_id, kb_id, table_name) + # Use the actual table_name from the mapping (in case we got an existing one) + table_name = mapping.table_name + except Exception as create_ex: + # If creation failed, try one more time to get existing mapping + # (another concurrent task might have created it) + logger.warning(f"Failed to create mapping, retrying to get existing: {create_ex}") + existing_mapping = await self._get_mapping(tenant_id, kb_id) + if existing_mapping: + logger.info(f"Found existing mapping after creation failure for tenant={tenant_id}, kb={kb_id}") + table_name = existing_mapping.table_name + else: + raise + + try: + await cache_manager.get_cache().set(cache_key, table_name) + except Exception as e: + logger.warning(f"Cache set operation failed for {cache_key}: {e}") return table_name async def _get_mapping( @@ -86,6 +109,7 @@ async def _create_mapping( ) -> VectorTableMappingEntity: """ Create a new vector table mapping. + Note: This method does not commit the transaction. The caller is responsible for committing. 
Args: tenant_id: The tenant ID @@ -95,6 +119,12 @@ async def _create_mapping( Returns: The created VectorTableMappingEntity """ + # First check if mapping already exists (race condition protection) + existing_mapping = await self._get_mapping(tenant_id, kb_id) + if existing_mapping: + logger.debug(f"Mapping already exists for tenant={tenant_id}, kb={kb_id}, returning existing") + return existing_mapping + mapping = VectorTableMappingEntity( tenant_id=tenant_id, kb_id=kb_id, @@ -102,17 +132,44 @@ async def _create_mapping( ) try: self.session.add(mapping) - await self.session.commit() - await self.session.refresh(mapping) - logger.info(f"Created vector table mapping: tenant={tenant_id}, kb={kb_id}, table={table_name}") - return mapping + try: + await self.session.flush() + await self.session.refresh(mapping) + logger.info(f"Created vector table mapping: tenant={tenant_id}, kb={kb_id}, table={table_name}") + return mapping + except Exception as flush_ex: + # Check if it's a "Session is already flushing" error + error_str = str(flush_ex).lower() + if "already flushing" in error_str or "invalidrequesterror" in error_str: + logger.warning(f"Session flush conflict detected, trying to get existing mapping for tenant={tenant_id}, kb={kb_id}") + # Remove the failed mapping from session to avoid conflicts + try: + self.session.expunge(mapping) + except Exception: + pass + # Try to get existing mapping (another concurrent task may have created it) + existing_mapping = await self._get_mapping(tenant_id, kb_id) + if existing_mapping: + logger.info(f"Found existing vector table mapping after flush conflict for tenant={tenant_id}, kb={kb_id}") + return existing_mapping + # If no existing mapping, log and re-raise + logger.error("No existing mapping found after flush conflict, re-raising error") + raise + raise except Exception as ex: logger.error(f"Failed to create vector table mapping: {traceback.format_exc()}") - await self.session.rollback() - if "UniqueViolationError" in 
str(ex.orig) or "Duplicate entry" in str(ex.orig): - pass - else: - raise ValueError(f"Failed to create vector table mapping: {ex}") from ex + # Check if it's a unique constraint violation + error_msg = str(ex) + if hasattr(ex, 'orig'): + error_msg = str(ex.orig) + + if "UniqueViolationError" in error_msg or "Duplicate entry" in error_msg or "UNIQUE constraint" in error_msg: + # If duplicate, try to get the existing mapping + existing_mapping = await self._get_mapping(tenant_id, kb_id) + if existing_mapping: + logger.info(f"Found existing vector table mapping after unique constraint violation for tenant={tenant_id}, kb={kb_id}") + return existing_mapping + raise ValueError(f"Failed to create vector table mapping: {ex}") from ex async def delete_mapping(self, tenant_id: str, kb_id: str) -> bool: """ @@ -128,7 +185,10 @@ async def delete_mapping(self, tenant_id: str, kb_id: str) -> bool: mapping = await self._get_mapping(tenant_id, kb_id) if mapping: cache_key = vector_table_name_key(tenant_id, kb_id) - await cache_manager.get_cache().delete(cache_key) + try: + await cache_manager.get_cache().delete(cache_key) + except Exception as e: + logger.warning(f"Cache delete operation failed for {cache_key}: {e}") await self.session.delete(mapping) await self.session.commit() logger.info(f"Deleted vector table mapping: tenant={tenant_id}, kb={kb_id}") @@ -147,13 +207,19 @@ async def get_table_name_if_exists(self, tenant_id: str, kb_id: str) -> Optional The table name if mapping exists, None otherwise """ cache_key = vector_table_name_key(tenant_id, kb_id) - table_name = await cache_manager.get_cache().get(cache_key) - if table_name: - logger.debug(f"Found vector table name in cache for tenant {tenant_id} and kb {kb_id}: {table_name}") - return table_name + try: + table_name = await cache_manager.get_cache().get(cache_key) + if table_name: + logger.debug(f"Found vector table name in cache for tenant {tenant_id} and kb {kb_id}: {table_name}") + return table_name + except 
Exception as e: + logger.warning(f"Cache get operation failed for {cache_key}: {e}, falling back to database") mapping = await self._get_mapping(tenant_id, kb_id) if mapping: - await cache_manager.get_cache().set(cache_key, mapping.table_name) + try: + await cache_manager.get_cache().set(cache_key, mapping.table_name) + except Exception as e: + logger.warning(f"Cache set operation failed for {cache_key}: {e}") return mapping.table_name return None diff --git a/backend/service/tool/chatapp_service.py b/backend/service/tool/chatapp_service.py index 79940a994..e9deb8590 100644 --- a/backend/service/tool/chatapp_service.py +++ b/backend/service/tool/chatapp_service.py @@ -8,7 +8,16 @@ from loguru import logger from db.models.chatbot import ChatBotCreate, ChatBotEntity +from db.models.knowledgebase.knowledgebase import KnowledgebaseCreate, RetrievalConfig, ChunkConfig, TableParserConfig, KbEntity from common.chat.response_model import PagedResult +from common.knowledgebase.constants import FAQ_KNOWLEDGEBASE_NAME, DEFAULT_FAQ_SIMILARITY_THRESHOLD +from common.knowledgebase.types import VectorIndexRetrievalType +from service.knowledgebase.knowledgebase_service import KnowledgebaseService +from service.model.embedding_service import EmbeddingService +from service.tool.faq_config_service import FAQConfigService +from db.models.faq_config import FAQConfigCreate +from db.models.faq_item import FAQItemEntity +from service.knowledgebase.rag_service import RagService class ChatappService: @@ -23,6 +32,79 @@ def __init__(self, session: AsyncSession): """ self.session = session + async def _ensure_faq_knowledgebase(self, chatbot_id: str, app_id: str, tenant_id: str) -> KbEntity: + """ + Ensure FAQ knowledgebase exists for the given chatbot_id and app_id. + Creates it if it doesn't exist. + Uses embedding_model from faq_config if available, otherwise uses default. 
+ + Args: + chatbot_id: ChatApp chatbot_id + app_id: ChatApp app_id + tenant_id: Tenant ID + + Returns: + KbEntity representing the FAQ knowledgebase + """ + kb_name = f"{app_id}_{FAQ_KNOWLEDGEBASE_NAME}" + knowledgebase_service = KnowledgebaseService(self.session) + embedding_service = EmbeddingService(self.session) + faq_config_service = FAQConfigService(self.session) + + knowledgebase = await knowledgebase_service.get_knowledgebase_by_name(kb_name, tenant_id=tenant_id) + + if not knowledgebase: + logger.info(f"Creating FAQ knowledgebase {kb_name} for app_id {app_id} and tenant {tenant_id}") + + # Get FAQ config to get embedding_model + faq_config = await faq_config_service.get_faq_config_by_chatbot_id( + chatbot_id=chatbot_id, tenant_id=tenant_id + ) + + # Use embedding_model from faq_config if available, otherwise use default + if faq_config and faq_config.embedding_model: + embedding_model = faq_config.embedding_model + logger.info(f"Using embedding_model {embedding_model} from FAQ config for knowledgebase {kb_name}") + else: + default_embedding_config = await embedding_service.get_default_embedding(tenant_id=tenant_id) + embedding_model = default_embedding_config.model_id + logger.info(f"Using default embedding_model {embedding_model} for knowledgebase {kb_name}") + + # Set default retrieval_config + default_similarity_threshold = faq_config.similarity_threshold if faq_config else DEFAULT_FAQ_SIMILARITY_THRESHOLD + + retrieval_config = RetrievalConfig( + retrieval_mode=VectorIndexRetrievalType.vector, + top_k=1, + enable_rerank=False, + rerank_top_k=None, + vector_weight=1.0, + similarity_threshold=default_similarity_threshold, + ) + + chunk_config = ChunkConfig( + table_config=TableParserConfig( + header_index_max=0, + question_column_index=0, + answer_column_index=1, + ), + parser_type="faq", + ) + + kb_create = KnowledgebaseCreate( + name=kb_name, + description="faq知识库", + embedding_model=embedding_model, + retrieval_config=retrieval_config, + 
chunk_config=chunk_config, + ) + knowledgebase = await knowledgebase_service.create_knowledgebase(kb_data=kb_create, tenant_id=tenant_id) + await self.session.flush() + await self.session.refresh(knowledgebase) + logger.info(f"Created FAQ knowledgebase {knowledgebase.id} (name: {kb_name}) for app_id {app_id}") + + return knowledgebase + async def get_chatapp(self, id: str, tenant_id: str) -> Optional[ChatBotEntity]: """ Get a single ChatApp entity by ID. @@ -104,13 +186,19 @@ async def create_chatapp(self, app_data: ChatBotCreate, tenant_id: str) -> ChatB Args: app_data: ChatApp creation data + tenant_id: Tenant ID Returns: Created ChatBotEntity (not yet committed) Raises: - ValueError: If app_id already exists (IntegrityError converted) + ValueError: If app_id already exists """ + # Check if app_id already exists + existing_chatbot = await self.get_chatapp_by_app_id(app_id=app_data.app_id, tenant_id=tenant_id) + if existing_chatbot: + raise ValueError(f"应用ID '{app_data.app_id}' 已经存在,无法创建。") + chatbot = ChatBotEntity.model_validate(app_data, update={"tenant_id": tenant_id}) self.session.add(chatbot) @@ -119,6 +207,33 @@ async def create_chatapp(self, app_data: ChatBotCreate, tenant_id: str) -> ChatB await self.session.flush() await self.session.refresh(chatbot) + # If enable_faq is True, create FAQ config + if app_data.enable_faq: + # Ensure FAQ knowledgebase exists first (to get kb_id) + knowledgebase = await self._ensure_faq_knowledgebase(chatbot.id, chatbot.app_id, tenant_id) + + # Initialize FAQ config with default values and set kb_id + faq_config_service = FAQConfigService(self.session) + faq_config = await faq_config_service.get_or_create_faq_config( + chatbot_id=chatbot.id, tenant_id=tenant_id + ) + + # Update faq_config with kb_id + if not faq_config.kb_id: + faq_config.kb_id = knowledgebase.id + await faq_config_service.update_faq_config( + chatbot_id=chatbot.id, + update_data=faq_config, + tenant_id=tenant_id + ) + + await self.session.flush() + await 
self.session.refresh(chatbot) + + logger.info( + f"Created FAQ config for ChatApp: {chatbot.id} (app_id: {chatbot.app_id})" + ) + logger.info( f"Created ChatApp entity: {chatbot.id} (app_id: {chatbot.app_id})" ) @@ -144,12 +259,13 @@ async def update_chatapp( Args: id: ChatApp entity ID update_data: Updated ChatApp data + tenant_id: Tenant ID Returns: Updated ChatBotEntity (not yet committed) Raises: - ValueError: If ChatApp entity not found + ValueError: If ChatApp entity not found or app_id already exists """ chatbot = await self.get_chatapp(id=id, tenant_id=tenant_id) if not chatbot: @@ -157,6 +273,54 @@ async def update_chatapp( logger.info(f"Updating ChatApp {id} with data: {update_data}") + # Check if app_id is being updated and if it conflicts with existing records + if update_data.app_id is not None and update_data.app_id != chatbot.app_id: + existing_chatbot = await self.get_chatapp_by_app_id(app_id=update_data.app_id, tenant_id=tenant_id) + if existing_chatbot and existing_chatbot.id != id: + raise ValueError(f"应用ID '{update_data.app_id}' 已经存在,无法更新。") + + faq_config_service = FAQConfigService(self.session) + if update_data.enable_faq: + if not chatbot.faq_config: + knowledgebase = await self._ensure_faq_knowledgebase(chatbot.id, chatbot.app_id, tenant_id) + + faq_config = await faq_config_service.get_or_create_faq_config( + chatbot_id=chatbot.id, tenant_id=tenant_id + ) + + # Update faq_config with kb_id + if not faq_config.kb_id: + faq_config.kb_id = knowledgebase.id + await faq_config_service.update_faq_config( + chatbot_id=chatbot.id, + update_data=faq_config, + tenant_id=tenant_id + ) + + logger.info( + f"Created FAQ config for ChatApp: {chatbot.id}" + ) + else: + current_config = FAQConfigCreate.model_validate(chatbot.faq_config) + if not current_config.active: + current_config.active = True + await faq_config_service.update_faq_config( + chatbot_id=chatbot.id, + update_data=current_config, + tenant_id=tenant_id + ) + logger.info(f"Updated FAQ config 
active to True for ChatApp: {chatbot.id}") + else: + current_config = FAQConfigCreate.model_validate(chatbot.faq_config) + + current_config.active = False + await faq_config_service.update_faq_config( + chatbot_id=chatbot.id, + update_data=current_config, + tenant_id=tenant_id + ) + logger.info(f"Disabled FAQ for ChatApp: {chatbot.id}") + # Update fields if update_data.app_id is not None: chatbot.app_id = update_data.app_id @@ -184,6 +348,8 @@ async def update_chatapp( chatbot.guardrail_hint = update_data.guardrail_hint if update_data.prompts is not None: chatbot.prompts = update_data.prompts + if update_data.enable_faq is not None: + chatbot.enable_faq = update_data.enable_faq chatbot.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) self.session.add(chatbot) @@ -195,13 +361,15 @@ async def update_chatapp( logger.info(f"Updated ChatApp entity: {chatbot.id} (app_id: {chatbot.app_id})") return chatbot - async def delete_chatapp(self, id: str, tenant_id: str) -> None: + async def delete_chatapp(self, id: str, tenant_id: str, rag_service: Optional[RagService] = None) -> None: """ Delete a ChatApp entity. Note: Caller is responsible for committing the session. 
Args: id: ChatApp entity ID + tenant_id: Tenant ID + rag_service: Optional RagService for deleting FAQ knowledgebase Raises: ValueError: If ChatApp entity not found @@ -210,7 +378,38 @@ async def delete_chatapp(self, id: str, tenant_id: str) -> None: if not chatbot: raise ValueError(f"应用 '{id}' 不存在。") + # Delete FAQ knowledgebase if exists + if chatbot.faq_config: + try: + faq_config = FAQConfigCreate.model_validate(chatbot.faq_config) + if faq_config.kb_id and rag_service: + try: + await rag_service.delete_knowledgebase(kb_id=faq_config.kb_id, tenant_id=tenant_id) + logger.info(f"Deleted FAQ knowledgebase {faq_config.kb_id} for ChatApp: {id}") + except Exception as e: + logger.warning(f"Failed to delete FAQ knowledgebase {faq_config.kb_id}: {e}") + # Continue with chatbot deletion even if KB deletion fails + except Exception as e: + logger.warning(f"Failed to parse FAQ config for chatbot {id}: {e}") + # Continue with chatbot deletion even if FAQ config parsing fails + + try: + faq_items = await self.session.exec( + select(FAQItemEntity).where( + FAQItemEntity.chatbot_id == chatbot.app_id, + FAQItemEntity.tenant_id == tenant_id + ) + ) + faq_items_list = list(faq_items.all()) + if faq_items_list: + for faq_item in faq_items_list: + await self.session.delete(faq_item) + logger.info(f"Deleted {len(faq_items_list)} FAQ items for ChatApp: {id}") + except Exception as e: + logger.warning(f"Failed to delete FAQ items for chatbot {id}: {e}") + # Delete from database (staged, not committed) + # FAQ items should be automatically deleted via CASCADE foreign key constraint if it exists await self.session.delete(chatbot) # Flush to ensure deletion is staged diff --git a/backend/service/tool/faq_config_service.py b/backend/service/tool/faq_config_service.py new file mode 100644 index 000000000..b4667f446 --- /dev/null +++ b/backend/service/tool/faq_config_service.py @@ -0,0 +1,265 @@ +"""FAQ Config Service layer for database operations.""" + +from datetime import datetime, 
timezone +from typing import Optional +from sqlmodel import select +from sqlmodel.ext.asyncio.session import AsyncSession +from common.knowledgebase.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_FAQ_SIMILARITY_THRESHOLD, FAQ_KNOWLEDGEBASE_NAME +from common.knowledgebase.types import VectorIndexRetrievalType +from loguru import logger + +from db.models.faq_config import FAQConfigCreate +from db.models.chatbot import ChatBotEntity +from db.models.knowledgebase.knowledgebase import KnowledgebaseCreate, RetrievalConfig +from service.knowledgebase.knowledgebase_service import KnowledgebaseService + + +class FAQConfigService: + """Service layer for FAQ Config operations using dependency injection.""" + + def __init__(self, session: AsyncSession): + """ + Initialize FAQConfigService with a database session. + + Args: + session: Database session (injected dependency) + """ + self.session = session + + def _get_default_faq_config(self) -> dict: + """Get default FAQ config values.""" + return { + "active": True, + "similarity_threshold": DEFAULT_FAQ_SIMILARITY_THRESHOLD, + "embedding_model": DEFAULT_EMBEDDING_MODEL, + "enable_question_in_retrieval": True, + "enable_question_in_response": True, + "enable_answer_in_retrieval": False, + "enable_answer_in_response": True, + "return_direct": False, + "kb_id": None, + } + + async def get_faq_config_by_chatbot_id( + self, chatbot_id: str, tenant_id: str + ) -> Optional[FAQConfigCreate]: + """ + Get FAQ Config by chatbot_id. 
+ + Args: + chatbot_id: Chatbot ID + tenant_id: Tenant ID + + Returns: + FAQConfigCreate if found, None otherwise + """ + chatbot = await self.session.exec( + select(ChatBotEntity).where( + ChatBotEntity.id == chatbot_id, + ChatBotEntity.tenant_id == tenant_id, + ) + ) + chatbot = chatbot.first() + if not chatbot or not chatbot.faq_config: + return None + + # Convert dict to FAQConfigCreate + return FAQConfigCreate.model_validate(chatbot.faq_config) + + async def get_or_create_faq_config( + self, chatbot_id: str, tenant_id: str + ) -> FAQConfigCreate: + """ + Get or create a FAQ config for a chatbot. + + Args: + chatbot_id: Chatbot ID + tenant_id: Tenant ID + + Returns: + FAQConfigCreate representing the FAQ config + """ + chatbot = await self.session.exec( + select(ChatBotEntity).where( + ChatBotEntity.id == chatbot_id, + ChatBotEntity.tenant_id == tenant_id, + ) + ) + chatbot = chatbot.first() + + if not chatbot: + raise ValueError(f"Chatbot '{chatbot_id}' 不存在。") + + # If faq_config exists and is not empty, return it + if chatbot.faq_config: + logger.info( + f"Found existing FAQ config for chatbot_id: {chatbot_id}" + ) + return FAQConfigCreate.model_validate(chatbot.faq_config) + + # Create new FAQ config with default values + default_config = self._get_default_faq_config() + chatbot.faq_config = default_config + self.session.add(chatbot) + + await self.session.flush() + await self.session.refresh(chatbot) + logger.info( + f"Created FAQ config for chatbot_id: {chatbot_id}" + ) + return FAQConfigCreate.model_validate(default_config) + + async def update_faq_config( + self, chatbot_id: str, update_data: FAQConfigCreate, tenant_id: str + ) -> FAQConfigCreate: + """ + Update FAQ config for a chatbot. + Note: Caller is responsible for committing the session. 
+ + Args: + chatbot_id: Chatbot ID + update_data: Updated FAQ Config data + tenant_id: Tenant ID + + Returns: + Updated FAQConfigCreate + + Raises: + ValueError: If Chatbot not found + """ + chatbot = await self.session.exec( + select(ChatBotEntity).where( + ChatBotEntity.id == chatbot_id, + ChatBotEntity.tenant_id == tenant_id, + ) + ) + chatbot = chatbot.first() + + if not chatbot: + raise ValueError(f"Chatbot '{chatbot_id}' 不存在。") + + logger.info(f"Updating FAQ Config for chatbot {chatbot_id} with data: {update_data}") + + # Get current config or use defaults + current_config = chatbot.faq_config.copy() if chatbot.faq_config else self._get_default_faq_config() + + # Update fields from update_data + update_dict = update_data.model_dump(exclude_unset=True) + current_config.update(update_dict) + + # Update chatbot's faq_config + chatbot.faq_config = current_config + chatbot.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + self.session.add(chatbot) + + # Flush to ensure changes are staged + await self.session.flush() + await self.session.refresh(chatbot) + + logger.info(f"Updated FAQ Config for chatbot: {chatbot_id}") + return FAQConfigCreate.model_validate(chatbot.faq_config) + + async def update_faq_config_with_sync( + self, + app_id: str, + chatbot_id: str, + update_data: FAQConfigCreate, + tenant_id: str, + knowledgebase_service: Optional[KnowledgebaseService] = None + ) -> FAQConfigCreate: + """ + Update FAQ config with full synchronization logic: + - Get or create FAQ config + - Sync chatbot.enable_faq with faq_config.active + - Update FAQ config + - Update corresponding knowledgebase if embedding_model or similarity_threshold changed + + Note: Caller is responsible for committing the session. 
+ + Args: + app_id: Chatbot app_id (used for knowledgebase name) + chatbot_id: Chatbot ID + update_data: Updated FAQ Config data + tenant_id: Tenant ID + knowledgebase_service: Optional KnowledgebaseService for updating knowledgebase + + Returns: + Updated FAQConfigCreate + + Raises: + ValueError: If Chatbot not found + """ + # Get chatbot entity + chatbot = await self.session.exec( + select(ChatBotEntity).where( + ChatBotEntity.id == chatbot_id, + ChatBotEntity.tenant_id == tenant_id, + ) + ) + chatbot = chatbot.first() + + if not chatbot: + raise ValueError(f"Chatbot '{chatbot_id}' 不存在。") + + # Get or create FAQ config (this will use the same chatbot entity if it exists) + await self.get_or_create_faq_config( + chatbot_id=chatbot_id, tenant_id=tenant_id + ) + + # Refresh chatbot to get latest state + await self.session.refresh(chatbot) + + # Sync chatbot.enable_faq with faq_config.active + if update_data.active is not None: + if chatbot.enable_faq != update_data.active: + chatbot.enable_faq = update_data.active + logger.info(f"Synced chatbot.enable_faq to {update_data.active} for chatbot {chatbot_id}") + + # Update FAQ config + updated_faq_config = await self.update_faq_config( + chatbot_id=chatbot_id, + update_data=update_data, + tenant_id=tenant_id + ) + + # Update corresponding knowledgebase if embedding_model or similarity_threshold changed + if knowledgebase_service and (update_data.embedding_model is not None or update_data.similarity_threshold is not None): + kb_name = f"{app_id}_{FAQ_KNOWLEDGEBASE_NAME}" + kb = await knowledgebase_service.get_knowledgebase_by_name(kb_name, tenant_id=tenant_id) + + if kb: + # Prepare update data for knowledgebase + kb_update_data = KnowledgebaseCreate() + update_fields = [] + + # Update embedding_model if provided + if update_data.embedding_model is not None: + kb_update_data.embedding_model = update_data.embedding_model + update_fields.append(f"embedding_model={update_data.embedding_model}") + + # Update 
retrieval_config.similarity_threshold if provided + if update_data.similarity_threshold is not None: + # Get current retrieval_config or create default + current_retrieval_config = RetrievalConfig.model_validate(kb.retrieval_config) if kb.retrieval_config else RetrievalConfig( + retrieval_mode=VectorIndexRetrievalType.vector, + top_k=1, + enable_rerank=False, + rerank_top_k=None, + vector_weight=1.0, + similarity_threshold=update_data.similarity_threshold, + ) + # Update similarity_threshold + current_retrieval_config.similarity_threshold = update_data.similarity_threshold + kb_update_data.retrieval_config = current_retrieval_config + update_fields.append(f"similarity_threshold={update_data.similarity_threshold}") + + # Update knowledgebase only if there are fields to update + if kb_update_data.embedding_model is not None or kb_update_data.retrieval_config is not None: + await knowledgebase_service.update_knowledgebase( + kb_id=kb.id, + update_data=kb_update_data, + tenant_id=tenant_id + ) + logger.info(f"Updated FAQ knowledgebase {kb_name} with {', '.join(update_fields)}") + + return updated_faq_config diff --git a/backend/service/tool/faq_item_service.py b/backend/service/tool/faq_item_service.py new file mode 100644 index 000000000..2dc6b16d0 --- /dev/null +++ b/backend/service/tool/faq_item_service.py @@ -0,0 +1,359 @@ +"""FAQ Item Service layer for database operations.""" + +from datetime import datetime, timezone +from typing import Optional, List +from sqlmodel import select, func +from sqlmodel.ext.asyncio.session import AsyncSession +from loguru import logger + +from db.models.faq_item import FAQItemCreate, FAQItemEntity +from db.models.faq_config import FAQConfigCreate +from db.models.chatbot import ChatBotEntity +from common.chat.response_model import PagedResult +from service.knowledgebase.knowledgebase_service import KnowledgebaseService +from service.knowledgebase.rag_service import RagService +from pairag.file.utils.tokenization import 
estimate_tokens_in_text +from llama_index.core.schema import TextNode + + +class FAQItemService: + """Service layer for FAQ Item entity CRUD operations using dependency injection.""" + + def __init__(self, session: AsyncSession): + """ + Initialize FAQItemService with a database session. + + Args: + session: Database session (injected dependency) + """ + self.session = session + + async def get_faq_knowledgebase(self, chatbot_id: str, tenant_id: str): + """ + Get FAQ knowledgebase for the given chatbot. + + Args: + chatbot_id: Chatbot ID + tenant_id: Tenant ID + + Returns: + Knowledgebase entity if found, None otherwise + """ + # Get chatbot to get app_id + chatbot = await self.session.exec( + select(ChatBotEntity).where( + ChatBotEntity.app_id == chatbot_id, ChatBotEntity.tenant_id == tenant_id + ) + ) + chatbot = chatbot.first() + if not chatbot: + return None + + # Convert dict to FAQConfigCreate object + if not chatbot.faq_config: + return None + + try: + faq_config = FAQConfigCreate.model_validate(chatbot.faq_config) + except Exception as e: + logger.warning(f"Failed to validate FAQ config for chatbot {chatbot_id}: {e}") + return None + + if not faq_config.kb_id: + return None + + knowledgebase_service = KnowledgebaseService(self.session) + return await knowledgebase_service.get_knowledgebase(faq_config.kb_id, tenant_id=tenant_id) + + async def save_faq_to_knowledgebase( + self, faq_item: FAQItemEntity, tenant_id: str, rag_service: RagService + ) -> None: + """ + Save FAQ item to knowledgebase. 
+ + Args: + faq_item: FAQ Item entity + tenant_id: Tenant ID + """ + try: + if not faq_item.question or not faq_item.answer: + logger.warning( + f"FAQ item {faq_item.id} has no question or answer, skipping save to KB" + ) + return + + # Get FAQ knowledgebase + kb = await self.get_faq_knowledgebase(faq_item.chatbot_id, tenant_id) + if not kb: + logger.warning( + f"FAQ knowledgebase not found for chatbot {faq_item.chatbot_id}, skipping save to KB" + ) + return + + # Get FAQ config from chatbot to determine what to include in chunk_text + chatbot = await self.session.exec( + select(ChatBotEntity).where( + ChatBotEntity.app_id == faq_item.chatbot_id, + ChatBotEntity.tenant_id == tenant_id, + ) + ) + chatbot = chatbot.first() + + faq_config = None + if chatbot and chatbot.faq_config: + faq_config = FAQConfigCreate.model_validate(chatbot.faq_config) + + # Build chunk_text based on faq_config settings + chunk_parts = [] + if faq_config: + if faq_config.enable_question_in_retrieval: + chunk_parts.append(f"{faq_item.question}") + if faq_config.enable_answer_in_retrieval: + chunk_parts.append(f"{faq_item.answer}") + else: + chunk_parts.append(f"{faq_item.question}") + + chunk_text = "\n".join(chunk_parts) if chunk_parts else "" + + # Create metadata for TextNode + node_metadata = { + "faq_item_id": faq_item.id, + "chatbot_id": faq_item.chatbot_id, + "question": faq_item.question, + "answer": faq_item.answer, + "token_count": estimate_tokens_in_text(chunk_text), + } + + # Create TextNode directly (no KbChunkEntity needed for FAQ items) + kb_node = TextNode( + id_=faq_item.id, + text=chunk_text, + metadata=node_metadata, + ) + + + if faq_item.active: + await rag_service.ainsert(kb_id=kb.id, nodes=[kb_node], tenant_id=tenant_id) + logger.info( + f"Inserted FAQ item {faq_item.id} into knowledgebase {kb.id}" + ) + else: + logger.info( + f"FAQ item {faq_item.id} is inactive, skipping vector store insertion" + ) + + except Exception as e: + logger.error(f"Failed to save FAQ item to 
knowledgebase: {e}") + + async def delete_faq_from_knowledgebase( + self, faq_item: FAQItemEntity, tenant_id: str, rag_service: RagService + ) -> None: + """ + Delete FAQ item from knowledgebase. + + Args: + faq_item: FAQ Item entity + tenant_id: Tenant ID + """ + try: + # Get FAQ knowledgebase + kb = await self.get_faq_knowledgebase(faq_item.chatbot_id, tenant_id) + if not kb: + logger.warning( + f"FAQ knowledgebase not found for chatbot {faq_item.chatbot_id}, skipping delete from KB" + ) + return + + # Delete from vector store using faq_item.id as node_id + await rag_service.adelete(kb_id=kb.id, node_ids=[faq_item.id], tenant_id=tenant_id) + logger.info( + f"Deleted FAQ item {faq_item.id} from knowledgebase {kb.id}" + ) + + except Exception as e: + logger.error(f"Failed to delete FAQ item from knowledgebase: {e}") + # Don't raise exception, just log the error + + async def get_faq_item(self, id: str, tenant_id: str) -> Optional[FAQItemEntity]: + """ + Get a single FAQ Item entity by ID. + + Args: + id: FAQ Item entity ID + tenant_id: Tenant ID + + Returns: + FAQItemEntity if found, None otherwise + """ + faq_items = await self.session.exec( + select(FAQItemEntity).where( + FAQItemEntity.id == id, FAQItemEntity.tenant_id == tenant_id + ) + ) + return faq_items.first() + + async def list_faq_items( + self, + chatbot_id: str, + tenant_id: str = None, + page: int = 1, + size: int = 100, + ) -> PagedResult[List[FAQItemEntity]]: + """ + List FAQ Item entities with pagination. 
+
+        Args:
+            chatbot_id: Chatbot ID
+            tenant_id: Tenant ID
+            page: Page number (1-indexed)
+            size: Page size
+
+        Returns:
+            PagedResult containing list of FAQItemEntity and pagination metadata
+        """
+        # Build base query
+        base_query = select(FAQItemEntity).where(
+            FAQItemEntity.chatbot_id == chatbot_id,
+            FAQItemEntity.tenant_id == tenant_id,
+        )
+
+        # Get total count
+        count_query = select(func.count()).select_from(base_query.subquery())
+        total_result = await self.session.exec(count_query)
+        total = total_result.one_or_none() or 0
+
+        # Get paginated results
+        offset = (page - 1) * size
+        paginated_query = (
+            base_query.offset(offset).limit(size).order_by(FAQItemEntity.created_at.desc())
+        )
+        results = await self.session.exec(paginated_query)
+        faq_items = list(results.all())
+
+        # Calculate pages
+        pages = (total + size - 1) // size if total > 0 else 0
+
+        return PagedResult(
+            items=faq_items,
+            total=total,
+            pages=pages,
+            page=page,
+            size=size,
+        )
+
+    async def create_faq_item(
+        self,
+        chatbot_id: str,
+        faq_item_data: FAQItemCreate,
+        tenant_id: str,
+    ) -> FAQItemEntity:
+        """
+        Create a new FAQ Item entity.
+        Note: Caller is responsible for committing the session.
+ + Args: + chatbot_id: Chatbot ID + faq_item_data: FAQ Item creation data + tenant_id: Tenant ID + + Returns: + Created FAQItemEntity (not yet committed) + """ + faq_item = FAQItemEntity.model_validate( + faq_item_data, + update={"chatbot_id": chatbot_id, "tenant_id": tenant_id}, + ) + self.session.add(faq_item) + + try: + # Flush to get the ID, but don't commit + await self.session.flush() + await self.session.refresh(faq_item) + + logger.info( + f"Created FAQ Item entity: {faq_item.id} (chatbot_id: {chatbot_id})" + ) + + + return faq_item + except Exception as e: + logger.error(f"Error creating FAQ Item: {e}") + raise ValueError(f"创建FAQ条目失败: {e}") from e + + async def update_faq_item( + self, id: str, update_data: FAQItemCreate, tenant_id: str, rag_service: Optional[RagService] = None + ) -> FAQItemEntity: + """ + Update an existing FAQ Item entity. + Note: Caller is responsible for committing the session. + + Args: + id: FAQ Item entity ID + update_data: Updated FAQ Item data + tenant_id: Tenant ID + + Returns: + Updated FAQItemEntity (not yet committed) + + Raises: + ValueError: If FAQ Item entity not found + """ + faq_item = await self.get_faq_item(id=id, tenant_id=tenant_id) + if not faq_item: + raise ValueError(f"FAQ条目 '{id}' 不存在。") + + logger.info(f"Updating FAQ Item {id} with data: {update_data}") + + # Update fields + if update_data.question is not None: + faq_item.question = update_data.question + if update_data.answer is not None: + faq_item.answer = update_data.answer + if update_data.chatbot_id is not None: + faq_item.chatbot_id = update_data.chatbot_id + if update_data.file_id is not None: + faq_item.file_id = update_data.file_id + if update_data.active is not None: + faq_item.active = update_data.active + + faq_item.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + self.session.add(faq_item) + + # Flush to ensure changes are staged + await self.session.flush() + await self.session.refresh(faq_item) + + logger.info(f"Updated FAQ Item 
entity: {faq_item.id}") + + # Update FAQ in knowledgebase (delete old, insert new if active) + await self.delete_faq_from_knowledgebase(faq_item, tenant_id, rag_service) + await self.save_faq_to_knowledgebase(faq_item, tenant_id, rag_service) + + return faq_item + + async def delete_faq_item(self, id: str, tenant_id: str, rag_service: Optional[RagService] = None) -> None: + """ + Delete a FAQ Item entity. + Note: Caller is responsible for committing the session. + + Args: + id: FAQ Item entity ID + tenant_id: Tenant ID + + Raises: + ValueError: If FAQ Item entity not found + """ + faq_item = await self.get_faq_item(id=id, tenant_id=tenant_id) + if not faq_item: + raise ValueError(f"FAQ条目 '{id}' 不存在。") + + # Delete from knowledgebase first + await self.delete_faq_from_knowledgebase(faq_item, tenant_id, rag_service) + + # Delete from database (staged, not committed) + await self.session.delete(faq_item) + + # Flush to ensure deletion is staged + await self.session.flush() + + logger.info(f"Deleted FAQ Item entity: {id}") diff --git a/backend/tools/attachments/file_reader.py b/backend/tools/attachments/file_reader.py index 5abe0964a..34166c611 100644 --- a/backend/tools/attachments/file_reader.py +++ b/backend/tools/attachments/file_reader.py @@ -40,5 +40,6 @@ async def aread_file_content( async_fn=aread_file_content, name="read-file", description="根据提供的附件ID读取文件的内容。\n参数:\n- file_id (str, 必需): 要读取的文件的ID。", + return_direct=False, ) return read_file_tool diff --git a/backend/tools/attachments/file_searcher.py b/backend/tools/attachments/file_searcher.py index 65f04dae8..f8f9441ae 100644 --- a/backend/tools/attachments/file_searcher.py +++ b/backend/tools/attachments/file_searcher.py @@ -67,5 +67,6 @@ async def file_retrieve_handler( - doc_ids (List[str]): The IDs of the documents to search. - kwargs (dict, optional): Additional arguments for the tool. 
""", + return_direct=False, ) return search_tool diff --git a/backend/tools/attachments/image_parser.py b/backend/tools/attachments/image_parser.py index 3d419f3dc..d1e6ee485 100644 --- a/backend/tools/attachments/image_parser.py +++ b/backend/tools/attachments/image_parser.py @@ -101,7 +101,8 @@ async def aget_image_analysis_func( name="image-parser", description="""解析上传的图片内容。适用于用户提问涉及图片中的信息(如图表、文字、产品图等)。 参数: -- question: 可选,用户想问的具体问题,例如“图中智能床的价格是多少?”、“请提取表格数据”等。 +- question: 可选,用户想问的具体问题,例如"图中智能床的价格是多少?"、"请提取表格数据"等。 返回:包含图片分析结果的 JSON 对象。""", + return_direct=False, ) return image_parser_tool diff --git a/backend/tools/code/code_sandbox_tool.py b/backend/tools/code/code_sandbox_tool.py index 834bc53a9..662625b2c 100644 --- a/backend/tools/code/code_sandbox_tool.py +++ b/backend/tools/code/code_sandbox_tool.py @@ -678,6 +678,7 @@ async def _wrapped_list_files(file_dir_path: str, sandbox_id: str = None): async_fn=_wrapped_list_files, name="list-sandbox-files", description=description, + return_direct=False, ) async def aupload_files_to_code_sandbox(self, file_ids: List[str], sandbox_id: str = None): diff --git a/backend/tools/knowledgebase/faq_tool.py b/backend/tools/knowledgebase/faq_tool.py new file mode 100644 index 000000000..2bf6379fa --- /dev/null +++ b/backend/tools/knowledgebase/faq_tool.py @@ -0,0 +1,145 @@ +from service.knowledgebase.rag_service import RagService +from service.tool.chatapp_service import ChatappService +from service.tool.faq_config_service import FAQConfigService +from db.models.faq_config import FAQConfigCreate +from typing import Annotated, Optional +from functools import partial +from llama_index.core.tools import FunctionTool +import json +from loguru import logger + + +async def aget_faq_result( + query: str, + chatapp_id: str, + user_id: str | None = None, + rag_service: RagService | None = None, + chatapp_service: ChatappService | None = None, + faq_config_service: FAQConfigService | None = None, + tenant_id: str = None, +) -> str: + 
"""Get FAQ search result from FAQ knowledgebase""" + logger.info(f"Searching FAQ with chatapp_id {chatapp_id} and user {user_id}.") + + chatbot = await chatapp_service.get_chatapp_by_app_id( + app_id=chatapp_id, + tenant_id=tenant_id + ) + if not chatbot: + raise ValueError(f"应用 '{chatapp_id}' 不存在。") + + # Convert dict to FAQConfigCreate object + kb_id = None + if chatbot.faq_config: + try: + faq_config = FAQConfigCreate.model_validate(chatbot.faq_config) + kb_id = faq_config.kb_id + except Exception as e: + logger.warning(f"Failed to validate FAQ config for chatbot {chatapp_id}: {e}") + + kb = await rag_service.get_knowledgebase(kb_id=kb_id, tenant_id=tenant_id) + + if not kb: + raise ValueError(f"FAQ知识库 '{kb_id}' 不存在。") + + + records = await rag_service.aquery( + query=query, + user_id=user_id, + kb_id=kb.id, + tenant_id=tenant_id, + ) + + logger.info( + f"Retrieved {len(records)} FAQ results for query '{query}' from knowledgebase {kb.id}." + ) + + faq_config = None + try: + faq_config = await faq_config_service.get_faq_config_by_chatbot_id( + chatbot_id=chatbot.id, tenant_id=tenant_id + ) + except Exception as e: + logger.warning(f"Failed to get FAQ config: {e}, using defaults") + + question_in_response = faq_config.enable_question_in_response if faq_config else True + answer_in_response = faq_config.enable_answer_in_response if faq_config else True + + records_dict = [] + for record in records: + record_dict = record.model_dump() + metadata = record_dict.get('metadata', {}) or {} + + question = metadata.get('question', '') or '' + answer = metadata.get('answer', '') or '' + + content_parts = [] + if not faq_config.return_direct: + if question_in_response and question: + content_parts.append(f"问题:{question}") + if answer_in_response and answer: + content_parts.append(f"答案:{answer}") + + if content_parts: + record_dict['content'] = '\n'.join(content_parts) + else: + record_dict['content'] = answer + + records_dict.append(record_dict) + + return 
json.dumps({"result": records_dict}, ensure_ascii=False) + + +async def aget_faq_tool( + chatapp_id: str, + tenant_id: str, + user_id: Optional[str] = None, + rag_service: RagService = None, + chatapp_service: ChatappService = None, + faq_config_service: FAQConfigService = None, +): + """Create a FAQ search tool for the given chatapp_id.""" + # Get faq_config to determine return_direct value + return_direct = False + try: + chatbot = await chatapp_service.get_chatapp_by_app_id( + app_id=chatapp_id, + tenant_id=tenant_id + ) + if chatbot: + faq_config = await faq_config_service.get_faq_config_by_chatbot_id( + chatbot_id=chatbot.id, tenant_id=tenant_id + ) + if faq_config: + return_direct = faq_config.return_direct if faq_config.return_direct is not None else False + except Exception as e: + logger.warning(f"Failed to get FAQ config for return_direct: {e}, using default False") + + aquery_faq_func = partial( + aget_faq_result, + chatapp_id=chatapp_id, + user_id=user_id, + rag_service=rag_service, + chatapp_service=chatapp_service, + faq_config_service=faq_config_service, + tenant_id=tenant_id, + ) + + + async def query_faq_handler( + query: Annotated[ + str, + "根据上下文添加必要的背景信息,改写一个新的独立问题,使问题更完整,注意指代消解、完善主语等", + ] = "", + ): + return await aquery_faq_func( + query=query, + ) + + search_faq_tool = FunctionTool.from_defaults( + async_fn=query_faq_handler, + name=f"search-faq-{chatapp_id}", + description="根据上下文从FAQ知识库中搜索和用户查询相关的内容。", + return_direct=return_direct, + ) + return search_faq_tool diff --git a/backend/tools/knowledgebase/knowledgebase_tool.py b/backend/tools/knowledgebase/knowledgebase_tool.py index 22fb7fdb6..1e8c8bcc2 100644 --- a/backend/tools/knowledgebase/knowledgebase_tool.py +++ b/backend/tools/knowledgebase/knowledgebase_tool.py @@ -55,5 +55,6 @@ async def query_knowledgebase_handler( async_fn=query_knowledgebase_handler, name=f"search-knowledgebase-{kb_id}", description=f"根据上下文从知识库中搜索和用户查询相关的内容。\n知识库名称: {knowledgebase.name}\n知识库描述: 
{knowledgebase.description}\n", + return_direct=False, ) return search_knowledgebase_tool diff --git a/backend/tools/plan/plan_tool.py b/backend/tools/plan/plan_tool.py index 01d75172d..17371355d 100644 --- a/backend/tools/plan/plan_tool.py +++ b/backend/tools/plan/plan_tool.py @@ -32,6 +32,7 @@ async def plan_func( async_fn=plan_func, name="planning-tool", description=PLAN_TOOL_DESCRIPTION, + return_direct=False, ) return plan_tool @@ -53,6 +54,7 @@ async def response_func(): async_fn=response_func, name="respond-tool", description=RESPONSE_TOOL_DESCRIPTION, + return_direct=False, ) return response_tool diff --git a/backend/tools/search/visit_webpage.py b/backend/tools/search/visit_webpage.py index b789b82b7..7cf829c90 100644 --- a/backend/tools/search/visit_webpage.py +++ b/backend/tools/search/visit_webpage.py @@ -179,6 +179,7 @@ async def visit_webpage_handler( ] } """, + return_direct=False, ) return visit_tool diff --git a/backend/tools/think/simple_think_tool.py b/backend/tools/think/simple_think_tool.py index 864968f49..1c7fc83ac 100644 --- a/backend/tools/think/simple_think_tool.py +++ b/backend/tools/think/simple_think_tool.py @@ -48,6 +48,7 @@ async def simple_think_handler(thought: str): async_fn=simple_think_handler, name="think", description="记录思考内容。用于复杂推理或缓存记忆。", + return_direct=False, ) openai_tools = [] tools_name_to_fn = {} diff --git a/backend/tools/think/think_and_planning_tool.py b/backend/tools/think/think_and_planning_tool.py index 151c4651c..3c3dbc615 100644 --- a/backend/tools/think/think_and_planning_tool.py +++ b/backend/tools/think/think_and_planning_tool.py @@ -75,6 +75,7 @@ async def simple_think_handler( async_fn=simple_think_handler, name="think-and-planning", description="这是用于系统化思考与规划的工具,支持用户在面对复杂问题或任务时,分阶段梳理思考、规划和行动步骤。工具强调思考(thought)、计划(plan)与实际行动(action)的结合,通过编号(thoughtNumber)追踪过程。该工具不会获取新信息或更改数据库,只会将想法附加到记忆中。当需要复杂推理或某种缓存记忆时,可以使用它。", + return_direct=False, ) return think_tool diff --git a/backend/utils/upload_file_utils.py 
b/backend/utils/upload_file_utils.py index d86f2e7b2..e71070393 100644 --- a/backend/utils/upload_file_utils.py +++ b/backend/utils/upload_file_utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from fastapi import UploadFile from pairag.file.models.file_item import FileItem from pairag.file.store.file_store_helper import file_store @@ -50,6 +50,7 @@ class ParseFileTask(BaseModel): class StartParseTaskRequest(BaseModel): files: List[ParseFileTask] + chunk_config: Optional[dict] = None async def upload_file_names_async( diff --git a/frontend/app/api/config/apps/[app_id]/faq-config/route.ts b/frontend/app/api/config/apps/[app_id]/faq-config/route.ts new file mode 100644 index 000000000..2139ae416 --- /dev/null +++ b/frontend/app/api/config/apps/[app_id]/faq-config/route.ts @@ -0,0 +1,12 @@ +// app/api/proxy/route.js +import { NextRequest } from 'next/server'; +import { proxyRequest } from '@/app/api/proxy'; + +export async function GET(request: NextRequest) { + return proxyRequest(request); +} + +export async function PUT(request: NextRequest) { + return proxyRequest(request); +} + diff --git a/frontend/app/api/config/apps/[app_id]/faq-files/route.ts b/frontend/app/api/config/apps/[app_id]/faq-files/route.ts new file mode 100644 index 000000000..077e4acd9 --- /dev/null +++ b/frontend/app/api/config/apps/[app_id]/faq-files/route.ts @@ -0,0 +1,8 @@ +// app/api/proxy/route.js +import { NextRequest } from 'next/server'; +import { proxyRequest } from '@/app/api/proxy'; + +export async function POST(request: NextRequest) { + return proxyRequest(request); +} + diff --git a/frontend/app/api/config/apps/[app_id]/faqs/[faq_id]/route.ts b/frontend/app/api/config/apps/[app_id]/faqs/[faq_id]/route.ts new file mode 100644 index 000000000..7f08ff7a0 --- /dev/null +++ b/frontend/app/api/config/apps/[app_id]/faqs/[faq_id]/route.ts @@ -0,0 +1,16 @@ +// app/api/proxy/route.js +import { NextRequest } from 'next/server'; +import { proxyRequest } from 
'@/app/api/proxy'; + +export async function GET(request: NextRequest) { + return proxyRequest(request); +} + +export async function PUT(request: NextRequest) { + return proxyRequest(request); +} + +export async function DELETE(request: NextRequest) { + return proxyRequest(request); +} + diff --git a/frontend/app/api/config/apps/[app_id]/faqs/route.ts b/frontend/app/api/config/apps/[app_id]/faqs/route.ts new file mode 100644 index 000000000..602a9149e --- /dev/null +++ b/frontend/app/api/config/apps/[app_id]/faqs/route.ts @@ -0,0 +1,20 @@ +// app/api/proxy/route.js +import { NextRequest } from 'next/server'; +import { proxyRequest } from '@/app/api/proxy'; + +export async function GET(request: NextRequest) { + return proxyRequest(request); +} + +export async function POST(request: NextRequest) { + return proxyRequest(request); +} + +export async function PUT(request: NextRequest) { + return proxyRequest(request); +} + +export async function DELETE(request: NextRequest) { + return proxyRequest(request); +} + diff --git a/frontend/app/api/config/knowledgebases/[kb_id]/files/[file_id]/chunk_config/route.ts b/frontend/app/api/config/knowledgebases/[kb_id]/files/[file_id]/chunk_config/route.ts new file mode 100644 index 000000000..602a9149e --- /dev/null +++ b/frontend/app/api/config/knowledgebases/[kb_id]/files/[file_id]/chunk_config/route.ts @@ -0,0 +1,20 @@ +// app/api/proxy/route.js +import { NextRequest } from 'next/server'; +import { proxyRequest } from '@/app/api/proxy'; + +export async function GET(request: NextRequest) { + return proxyRequest(request); +} + +export async function POST(request: NextRequest) { + return proxyRequest(request); +} + +export async function PUT(request: NextRequest) { + return proxyRequest(request); +} + +export async function DELETE(request: NextRequest) { + return proxyRequest(request); +} + diff --git a/frontend/app/api/proxy.ts b/frontend/app/api/proxy.ts index 5203e6c08..968d2072b 100644 --- a/frontend/app/api/proxy.ts +++ 
b/frontend/app/api/proxy.ts @@ -68,23 +68,73 @@ export async function proxyRequest(request: NextRequest) { } // 创建 AbortController 用于超时控制 - const timeoutMs = parseInt(process.env.PROXY_TIMEOUT_MS || '60000', 10); // 默认 60 秒 + // 对于可能返回流式响应或需要LLM调用的请求,使用更长的超时时间 + const defaultTimeoutMs = parseInt(process.env.PROXY_TIMEOUT_MS || '60000', 10); // 默认 60 秒 + const streamingTimeoutMs = parseInt(process.env.PROXY_STREAMING_TIMEOUT_MS || '300000', 10); // 流式响应默认 5 分钟 + + // 判断是否是可能返回流式响应或需要LLM调用的请求路径 + // 包括: + // - /threads/* 下的所有路径(可能涉及LLM调用,如生成标题、消息等) + // - /chat/completions 和 /chat(流式响应) + const isPotentialStreamingPath = pathname.includes('/threads/') || + pathname.includes('/chat/completions') || + pathname.includes('/chat'); + + // 对于可能返回流式响应或需要LLM调用的请求,使用更长的超时时间 + const initialTimeoutMs = isPotentialStreamingPath ? streamingTimeoutMs : defaultTimeoutMs; + const controller = new AbortController(); let timeoutId: NodeJS.Timeout | null = null; try { - timeoutId = setTimeout(() => controller.abort(), timeoutMs); + // 根据请求路径设置初始超时时间 + timeoutId = setTimeout(() => controller.abort(), initialTimeoutMs); const res = await fetch(upstreamUrl.toString(), { method, headers, body, signal: controller.signal, + // 添加 keepalive 选项,保持连接活跃 + keepalive: true, }); + // 检查是否是流式响应 + const contentType = res.headers.get('content-type') || ''; + const isStreaming = contentType.includes('text/event-stream') || + contentType.includes('stream') || + res.headers.get('transfer-encoding') === 'chunked'; + + if (isStreaming && res.body) { + // 流式响应:清除当前超时,使用更长的超时时间 + if (timeoutId) clearTimeout(timeoutId); + + // 对于流式响应,创建一个新的超时控制器,使用更长的超时时间 + // 注意:这里我们不能直接修改signal,但可以在流式传输过程中监控 + // 实际上,对于流式响应,我们应该让客户端控制超时,而不是在代理层强制超时 + // 流式响应:直接传递流,不设置超时限制(由客户端或Next.js处理) + return new NextResponse(res.body, { + status: res.status, + statusText: res.statusText, + headers: res.headers, + }); + } + + // 非流式响应:清除超时(响应已完全接收) if (timeoutId) clearTimeout(timeoutId); - // 读取响应数据 + // 检查响应是否正常 + if (!res.ok && 
!res.body) { + return NextResponse.json( + { + error: 'Proxy request failed', + message: `Backend returned status ${res.status} without body`, + }, + { status: res.status } + ); + } + + // 非流式响应:读取完整数据 const responseData = await res.blob(); // 通用处理(支持 JSON、text、binary) const responseHeaders = new Headers(res.headers); responseHeaders.set('content-length', responseData.size.toString()); @@ -99,25 +149,57 @@ export async function proxyRequest(request: NextRequest) { }); } catch (error: any) { if (timeoutId) clearTimeout(timeoutId); - console.log("Proxy request failed: ", error); - // 处理超时错误 - if (error.name === 'AbortError' || error.code === 'UND_ERR_HEADERS_TIMEOUT') { + // 记录错误详情用于调试 + console.error("Proxy request failed: ", { + name: error.name, + message: error.message, + code: error.code, + cause: error.cause, + stack: error.stack + }); + + // 处理连接关闭错误 + if (error.cause?.code === 'UND_ERR_SOCKET' || + error.message?.includes('other side closed') || + error.message?.includes('fetch failed') || + error.message?.includes('ECONNREFUSED') || + error.message?.includes('ENOTFOUND')) { + return NextResponse.json( + { + error: 'Proxy connection closed', + message: 'Backend connection was closed unexpectedly. This may happen if the request takes too long or the backend service restarted.', + details: error.cause?.message || error.message, + code: error.cause?.code || 'CONNECTION_CLOSED' + }, + { status: 502 } // Bad Gateway - 后端服务问题 + ); + } + + // 处理超时错误(包括AbortError) + if (error.name === 'AbortError' || + error.code === 'UND_ERR_HEADERS_TIMEOUT' || + error.code === 20 || // DOMException.ABORT_ERR + error.message?.includes('aborted') || + error.message?.includes('This operation was aborted')) { return NextResponse.json( { error: 'Proxy request timeout', - message: `Request exceeded timeout of ${timeoutMs}ms`, - details: error.message + message: `Request exceeded timeout of ${initialTimeoutMs}ms. ${isPotentialStreamingPath ? 
'This is a streaming endpoint, which may take longer to respond.' : 'Please try again or contact support if the issue persists.'}`, + details: error.message, + code: 'TIMEOUT' }, { status: 504 } ); } + // 处理其他错误 return NextResponse.json( { error: 'Proxy request failed', message: error.message || String(error), - details: error.cause?.message || error.stack + details: error.cause?.message || error.stack, + code: error.code || 'UNKNOWN_ERROR' }, { status: 500 } ); diff --git a/frontend/app/apps/[appId]/page.tsx b/frontend/app/apps/[appId]/page.tsx index c8b8ce71e..856432290 100644 --- a/frontend/app/apps/[appId]/page.tsx +++ b/frontend/app/apps/[appId]/page.tsx @@ -1,14 +1,107 @@ 'use client'; -import { use } from "react"; +import { use, useState, useEffect } from "react"; import { ChatbotConfigCard } from "../chatbot_config"; - +import { FAQManagement } from "../faq_management"; +import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs"; +import { Button } from "@/components/ui/button"; +import { Badge } from "@/components/ui/badge"; +import { + Breadcrumb, + BreadcrumbItem, + BreadcrumbLink, + BreadcrumbList, + BreadcrumbPage, + BreadcrumbSeparator, +} from '@/components/ui/breadcrumb'; +import { useRouter } from "next/navigation"; +import { useTenantFetch } from "@/hooks/use-tenant-fetch"; +import { Chatbot } from "../chatbot_config"; +import { toast } from "sonner"; export default function ViewChatApp( { params } : { params: Promise<{ appId: string }> } ) { const { appId } = use(params); + const router = useRouter(); + const { tenantFetch } = useTenantFetch(); + const [botConfig, setBotConfig] = useState(null); + const [loading, setLoading] = useState(true); + + useEffect(() => { + fetchAppConfig(); + }, [appId]); + + const fetchAppConfig = async () => { + try { + setLoading(true); + const res = await tenantFetch(`/api/config/apps?app_id=${appId}`); + if (res.ok) { + const data = await res.json(); + setBotConfig(data.data); + } + } catch 
(error: any) { + console.error('获取应用配置失败:', error); + toast.error('获取应用配置失败'); + } finally { + setLoading(false); + } + }; + + if (loading) { + return
加载中...
; + } + return ( - - - ) +
+
+ + + + + + + + + + {botConfig?.app_id || '应用编辑'} + + + +
+ + ID: {botConfig?.id || ''} + + {botConfig?.description && ( + + {botConfig.description} + + )} +
+
+
+ + + + 应用设置 + + + FAQ管理 + + + + + + + {botConfig && } + + +
+
+ ); } \ No newline at end of file diff --git a/frontend/app/apps/chatbot_config.tsx b/frontend/app/apps/chatbot_config.tsx index 8161455cf..f8aff17cf 100644 --- a/frontend/app/apps/chatbot_config.tsx +++ b/frontend/app/apps/chatbot_config.tsx @@ -53,8 +53,17 @@ interface PromptConfig { act: string; act_with_plan: string; summary: string; -}; +} +interface FAQConfig { + active?: boolean; + similarity_threshold?: number; + embedding_model?: string; + enable_question_in_retrieval?: boolean; + enable_question_in_response?: boolean; + enable_answer_in_retrieval?: boolean; + enable_answer_in_response?: boolean; +} export interface Chatbot { id: string; @@ -63,6 +72,8 @@ export interface Chatbot { enable_search: boolean; enable_agent: boolean; enable_chatdb: boolean; + enable_faq?: boolean; + faq_config?: FAQConfig | null; mcp_ids: string[]; kb_ids: string[]; model_id: string; @@ -90,6 +101,8 @@ const default_chat_config = { updated_at: "", enable_agent: false, enable_chatdb: false, + enable_faq: false, + faq_config: null, enable_input_guardrail: false, enable_output_guardrail: false, guardrail_hint: "作为人工智能助手,我无法回应包含不当或敏感信息的内容。", @@ -502,6 +515,21 @@ export const ChatbotConfigCard: FC = ({ }} /> +
+ + { + setBotConfig({ + ...botConfig, + enable_faq: checked, + }); + }} + /> +