diff --git a/backend/api/v1/mcp/kb_retriever_tool.py b/backend/api/v1/mcp/kb_retriever_tool.py index 48cb2ec6..e39d5e55 100644 --- a/backend/api/v1/mcp/kb_retriever_tool.py +++ b/backend/api/v1/mcp/kb_retriever_tool.py @@ -1,5 +1,4 @@ import traceback -import uuid from typing import List, Optional from loguru import logger from common.chat.models import RetrievalSetting @@ -25,16 +24,12 @@ class NodeResult(BaseModel): text: str = Field(description="The text content of the node") -class RetrievalResult(BaseModel): - total: int - nodes: List[NodeResult] = [] - - class RetrievalToolResponse(BaseModel): status: str status_code: int message: Optional[str] = None - data: RetrievalResult + total: int + nodes: List[NodeResult] = [] request_id: str @@ -47,9 +42,8 @@ async def asearch_knowledgebase( retrieval_setting: Optional[RetrievalSetting] = None, metadata_condition: Optional[MetadataFilteringCondition] = None, rag_service: RagService = None, + request_id: str = None, ) -> RetrievalToolResponse: - request_id = str(uuid.uuid4()) - # TODO: Handle images if needed in the future # For now, images are accepted but not used in retrieval if image_list: @@ -99,10 +93,8 @@ async def asearch_knowledgebase( return RetrievalToolResponse( status="SUCCESS", status_code=200, - data=RetrievalResult( - total=len(nodes), - nodes=nodes - ), + total=len(nodes), + nodes=nodes, message=None, request_id=request_id ) @@ -111,10 +103,8 @@ async def asearch_knowledgebase( return RetrievalToolResponse( status="ERROR", status_code=500, - data=RetrievalResult( - total=0, - nodes=[] - ), + total=0, + nodes=[], message=str(ex), request_id=request_id ) diff --git a/backend/api/v1/retrieval.py b/backend/api/v1/retrieval.py index 82348383..ee048e1b 100644 --- a/backend/api/v1/retrieval.py +++ b/backend/api/v1/retrieval.py @@ -1,15 +1,14 @@ from fastapi import APIRouter, Depends from common.chat.response_model import ResponseModel from api.api_exception import ApiException -from db.db_context import get_db_session from common.chat.models import DocRecord, NewRetrievalResponse, RetrievalRequest -from sqlmodel.ext.asyncio.session import AsyncSession from service.injection import get_rag_service, get_tenant_id from service.knowledgebase.rag_service import RagService from typing import List from common.tool.search_result import SearchResult from fastapi.responses import JSONResponse import traceback +from extensions.trace.context import get_request_id from loguru import logger @@ -20,11 +19,11 @@ ) async def retrieval( retrieval_request: RetrievalRequest, - session: AsyncSession = Depends(get_db_session), tenant_id: str = Depends(get_tenant_id), rag_service: RagService = Depends(get_rag_service), ): - logger.info(f"Retrieval request: {retrieval_request}, tenant_id: {tenant_id}") + request_id = get_request_id() + logger.info(f"Retrieval request: {retrieval_request}, tenant_id: {tenant_id}, request_id: {request_id}") try: search_results: List[SearchResult] = await rag_service.aquery( query=retrieval_request.query, @@ -36,7 +35,7 @@ async def retrieval( tenant_id=tenant_id, ) logger.info( - f"Retrieved {len(search_results)} for query '{retrieval_request.query}'." + f"Retrieved {len(search_results)} for query '{retrieval_request.query}', request_id: {request_id}" ) records = [] for node in search_results: @@ -48,7 +47,10 @@ async def retrieval( )) # 使用统一的响应格式 retrieval_response = NewRetrievalResponse(records=records) - return JSONResponse(status_code=200, content=retrieval_response.model_dump()) + retrieval_response_data = retrieval_response.model_dump() + logger.info(f"Retrieval response: {retrieval_response_data}, request_id: {request_id}") + + return JSONResponse(status_code=200, content=retrieval_response_data) except ValueError as e: logger.error(f"Failed to retrieve: {traceback.format_exc()}") raise ApiException(code=400, message=f"Failed to retrieve: {e}") diff --git a/backend/api/v1/retrieval_tool_api.py b/backend/api/v1/retrieval_tool_api.py index da8fae26..f28e2e9c 100644 --- a/backend/api/v1/retrieval_tool_api.py +++ b/backend/api/v1/retrieval_tool_api.py @@ -2,14 +2,13 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel from typing import List, Optional -from sqlmodel.ext.asyncio.session import AsyncSession -from db.db_context import get_db_session from common.chat.models import RetrievalSetting from api.v1.mcp.kb_retriever_tool import asearch_knowledgebase from common.chat.models import MetadataFilteringCondition from service.knowledgebase.rag_service import RagService from service.injection import get_rag_service, get_tenant_id from loguru import logger +from extensions.trace.context import get_request_id retrieval_tool_router = APIRouter() @@ -27,7 +26,6 @@ async def mcp_retrieval( knowledgebase_id: str, request: RetrievalToolRequest, tenant_id: str = Depends(get_tenant_id), - session: AsyncSession = Depends(get_db_session), rag_service: RagService = Depends(get_rag_service), ): """ @@ -35,7 +33,8 @@ async def mcp_retrieval( Input: {"query": "xxx", "images": ["1.jpg", "2.jpg"]} Output: {"status": "SUCCESS", "status_code": 200, "data": {"total": 2, "nodes": [...]}, "request_id": "..."} """ - logger.info(f"Retrieval tool request: {request}, knowledgebase_id: {knowledgebase_id}, tenant_id: {tenant_id}") + request_id = get_request_id() + logger.info(f"Retrieval tool request: {request}, knowledgebase_id: {knowledgebase_id}, tenant_id: {tenant_id}, request_id: {request_id}") result = await asearch_knowledgebase( query=request.query, @@ -45,5 +44,8 @@ async def mcp_retrieval( metadata_condition=request.metadata_condition, rag_service=rag_service, tenant_id=tenant_id, + request_id=request_id, ) - return JSONResponse(status_code=result.status_code, content=result.model_dump()) + result_data = result.model_dump() + logger.info(f"Retrieval tool response: {result_data}, request_id: {request_id}") + return JSONResponse(status_code=result.status_code, content=result_data) diff --git a/backend/common/llm/utils.py b/backend/common/llm/utils.py index c7704b82..ad229825 100644 --- a/backend/common/llm/utils.py +++ b/backend/common/llm/utils.py @@ -13,6 +13,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession from loguru import logger +from extensions.trace.context import get_request_id def parse_llm_json(json_str: str) -> dict: @@ -82,7 +83,7 @@ async def convert_gen_to_stream_chat_completions( enable_output_check = False chunk_index = 0 - chat_id = uuid.uuid4().hex + chat_id = get_request_id() or uuid.uuid4().hex total_usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) citations, citation_details = [], [] @@ -208,7 +209,8 @@ async def convert_gen_to_chat_completions( guardrail_hint: str | None = None, checker: GuardrailChecker | None = None, ): - chat_id = uuid.uuid4().hex + chat_id = get_request_id() or uuid.uuid4().hex + total_usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) reasoning_content = "" diff --git a/backend/extensions/trace/context.py b/backend/extensions/trace/context.py index e16cee69..40488b15 100644 --- a/backend/extensions/trace/context.py +++ b/backend/extensions/trace/context.py @@ -7,6 +7,14 @@ USER_NAME = "gen_ai.user.name" SESSION_ID = "gen_ai.session.id" +# baggage key for request_id from agentscope +AGENTSCOPE_REQUEST_ID_KEY = "traffic.llm_sdk.agentscope.request_id" + +# ContextVar for request_id +_request_id_var: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "request_id", default=None +) + custom_context_vars = {} @@ -25,3 +33,15 @@ def set_context_var(key, value): custom_context_vars[key].set(value) else: logger.warning(f"key: `{key}` is not in custom_context_vars") + + +def set_request_id(request_id: str | None): + """Set the request_id in the current context.""" + _request_id_var.set(request_id) + + +def get_request_id() -> str | None: + """Get the request_id from the current context.""" + request_id = _request_id_var.get() + + return request_id diff --git a/backend/extensions/trace/rag_wrapper.py b/backend/extensions/trace/rag_wrapper.py index ee0e22fa..ce0a941a 100644 --- a/backend/extensions/trace/rag_wrapper.py +++ b/backend/extensions/trace/rag_wrapper.py @@ -77,7 +77,6 @@ async def wrapper(self, *args, **kwargs): "score_threshold": retrieval_setting.similarity_threshold if retrieval_setting else None, } span.set_attribute(INPUT_VALUE, json.dumps(input_data, ensure_ascii=False)) - span.set_attribute(GEN_AI_OPERATION_NAME, RETRIEVER_OPERATION_NAME) results: List[SearchResult] = await func(self, *args, **kwargs) diff --git a/backend/extensions/trace/trace_context_middleware.py b/backend/extensions/trace/trace_context_middleware.py index 9a58deed..ac67f77c 100644 --- a/backend/extensions/trace/trace_context_middleware.py +++ b/backend/extensions/trace/trace_context_middleware.py @@ -4,6 +4,15 @@ from starlette.middleware.base import BaseHTTPMiddleware from opentelemetry.propagate import get_global_textmap from opentelemetry import context +from opentelemetry.baggage import get_baggage +import uuid + + +from extensions.trace.context import ( + AGENTSCOPE_REQUEST_ID_KEY, + set_request_id, +) + class TraceContextMiddleware(BaseHTTPMiddleware): """ @@ -19,7 +28,13 @@ async def dispatch(self, request: Request, call_next): # 在提取的 context 中执行 token = context.attach(extracted_context) try: + # 从 baggage 中获取 request_id 并设置到 context + request_id = get_baggage(AGENTSCOPE_REQUEST_ID_KEY) or uuid.uuid4().hex + set_request_id(request_id) + response = await call_next(request) return response finally: + # 清理 request_id + set_request_id(None) context.detach(token) diff --git a/backend/utils/format_logging.py b/backend/utils/format_logging.py index fb813654..b7c1b45c 100644 --- a/backend/utils/format_logging.py +++ b/backend/utils/format_logging.py @@ -1,7 +1,7 @@ import logging import sys +from extensions.trace.context import get_request_id from loguru import logger -from asgi_correlation_id.context import correlation_id class InterceptHandler(logging.Handler): @@ -28,7 +28,7 @@ def emit(self, record: logging.LogRecord) -> None: # 自定义日志格式,加入 request_id def formatter(record): - record["extra"]["request_id"] = correlation_id.get() + record["extra"]["request_id"] = get_request_id() if record["extra"].get("request_id", None): return ( "{time:YYYY-MM-DD HH:mm:ss.SSS} | "