Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 7 additions & 17 deletions backend/api/v1/mcp/kb_retriever_tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import traceback
import uuid
from typing import List, Optional
from loguru import logger
from common.chat.models import RetrievalSetting
Expand All @@ -25,16 +24,12 @@ class NodeResult(BaseModel):
text: str = Field(description="The text content of the node")


class RetrievalResult(BaseModel):
total: int
nodes: List[NodeResult] = []


class RetrievalToolResponse(BaseModel):
status: str
status_code: int
message: Optional[str] = None
data: RetrievalResult
total: int
nodes: List[NodeResult] = []
request_id: str


Expand All @@ -47,9 +42,8 @@ async def asearch_knowledgebase(
retrieval_setting: Optional[RetrievalSetting] = None,
metadata_condition: Optional[MetadataFilteringCondition] = None,
rag_service: RagService = None,
request_id: str = None,
) -> RetrievalToolResponse:
request_id = str(uuid.uuid4())

# TODO: Handle images if needed in the future
# For now, images are accepted but not used in retrieval
if image_list:
Expand Down Expand Up @@ -99,10 +93,8 @@ async def asearch_knowledgebase(
return RetrievalToolResponse(
status="SUCCESS",
status_code=200,
data=RetrievalResult(
total=len(nodes),
nodes=nodes
),
total=len(nodes),
nodes=nodes,
message=None,
request_id=request_id
)
Expand All @@ -111,10 +103,8 @@ async def asearch_knowledgebase(
return RetrievalToolResponse(
status="ERROR",
status_code=500,
data=RetrievalResult(
total=0,
nodes=[]
),
total=0,
nodes=[],
message=str(ex),
request_id=request_id
)
Expand Down
14 changes: 8 additions & 6 deletions backend/api/v1/retrieval.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from fastapi import APIRouter, Depends
from common.chat.response_model import ResponseModel
from api.api_exception import ApiException
from db.db_context import get_db_session
from common.chat.models import DocRecord, NewRetrievalResponse, RetrievalRequest
from sqlmodel.ext.asyncio.session import AsyncSession
from service.injection import get_rag_service, get_tenant_id
from service.knowledgebase.rag_service import RagService
from typing import List
from common.tool.search_result import SearchResult
from fastapi.responses import JSONResponse
import traceback
from extensions.trace.context import get_request_id
from loguru import logger


Expand All @@ -20,11 +19,11 @@
)
async def retrieval(
retrieval_request: RetrievalRequest,
session: AsyncSession = Depends(get_db_session),
tenant_id: str = Depends(get_tenant_id),
rag_service: RagService = Depends(get_rag_service),
):
logger.info(f"Retrieval request: {retrieval_request}, tenant_id: {tenant_id}")
request_id = get_request_id()
logger.info(f"Retrieval request: {retrieval_request}, tenant_id: {tenant_id}, request_id: {request_id}")
try:
search_results: List[SearchResult] = await rag_service.aquery(
query=retrieval_request.query,
Expand All @@ -36,7 +35,7 @@ async def retrieval(
tenant_id=tenant_id,
)
logger.info(
f"Retrieved {len(search_results)} for query '{retrieval_request.query}'."
f"Retrieved {len(search_results)} for query '{retrieval_request.query}', request_id: {request_id}"
)
records = []
for node in search_results:
Expand All @@ -48,7 +47,10 @@ async def retrieval(
))
# 使用统一的响应格式
retrieval_response = NewRetrievalResponse(records=records)
return JSONResponse(status_code=200, content=retrieval_response.model_dump())
retrieval_response_data = retrieval_response.model_dump()
logger.info(f"Retrieval response: {retrieval_response_data}, request_id: {request_id}")

return JSONResponse(status_code=200, content=retrieval_response_data)
except ValueError as e:
logger.error(f"Failed to retrieve: {traceback.format_exc()}")
raise ApiException(code=400, message=f"Failed to retrieve: {e}")
Expand Down
12 changes: 7 additions & 5 deletions backend/api/v1/retrieval_tool_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import List, Optional
from sqlmodel.ext.asyncio.session import AsyncSession
from db.db_context import get_db_session
from common.chat.models import RetrievalSetting
from api.v1.mcp.kb_retriever_tool import asearch_knowledgebase
from common.chat.models import MetadataFilteringCondition
from service.knowledgebase.rag_service import RagService
from service.injection import get_rag_service, get_tenant_id
from loguru import logger
from extensions.trace.context import get_request_id

retrieval_tool_router = APIRouter()

Expand All @@ -27,15 +26,15 @@ async def mcp_retrieval(
knowledgebase_id: str,
request: RetrievalToolRequest,
tenant_id: str = Depends(get_tenant_id),
session: AsyncSession = Depends(get_db_session),
rag_service: RagService = Depends(get_rag_service),
):
"""
Retrieval tool interface with different input/output format.
Input: {"query": "xxx", "images": ["1.jpg", "2.jpg"]}
Output: {"status": "SUCCESS", "status_code": 200, "message": null, "total": 2, "nodes": [...], "request_id": "..."}
"""
logger.info(f"Retrieval tool request: {request}, knowledgebase_id: {knowledgebase_id}, tenant_id: {tenant_id}")
request_id = get_request_id()
logger.info(f"Retrieval tool request: {request}, knowledgebase_id: {knowledgebase_id}, tenant_id: {tenant_id}, request_id: {request_id}")

result = await asearch_knowledgebase(
query=request.query,
Expand All @@ -45,5 +44,8 @@ async def mcp_retrieval(
metadata_condition=request.metadata_condition,
rag_service=rag_service,
tenant_id=tenant_id,
request_id=request_id,
)
return JSONResponse(status_code=result.status_code, content=result.model_dump())
result_data = result.model_dump()
logger.info(f"Retrieval tool response: {result_data}, request_id: {request_id}")
return JSONResponse(status_code=result.status_code, content=result_data)
6 changes: 4 additions & 2 deletions backend/common/llm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlmodel.ext.asyncio.session import AsyncSession

from loguru import logger
from extensions.trace.context import get_request_id


def parse_llm_json(json_str: str) -> dict:
Expand Down Expand Up @@ -82,7 +83,7 @@ async def convert_gen_to_stream_chat_completions(
enable_output_check = False

chunk_index = 0
chat_id = uuid.uuid4().hex
chat_id = get_request_id() or uuid.uuid4().hex
total_usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
citations, citation_details = [], []

Expand Down Expand Up @@ -208,7 +209,8 @@ async def convert_gen_to_chat_completions(
guardrail_hint: str | None = None,
checker: GuardrailChecker | None = None,
):
chat_id = uuid.uuid4().hex
chat_id = get_request_id() or uuid.uuid4().hex

total_usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)

reasoning_content = ""
Expand Down
20 changes: 20 additions & 0 deletions backend/extensions/trace/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
USER_NAME = "gen_ai.user.name"
SESSION_ID = "gen_ai.session.id"

# baggage key for request_id from agentscope
AGENTSCOPE_REQUEST_ID_KEY = "traffic.llm_sdk.agentscope.request_id"

# Context-local holder for the current request_id. It is written through
# set_request_id() and read through get_request_id(); a read yields None
# (the declared default) when no value has been stored in the active context.
_request_id_var: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"request_id", default=None
)


custom_context_vars = {}

Expand All @@ -25,3 +33,15 @@ def set_context_var(key, value):
custom_context_vars[key].set(value)
else:
logger.warning(f"key: `{key}` is not in custom_context_vars")


def set_request_id(request_id: str | None):
    """Record ``request_id`` as the request identifier of the current context.

    Args:
        request_id: Identifier to store; passing ``None`` clears the value.
    """
    _request_id_var.set(request_id)


def get_request_id() -> str | None:
    """Return the request identifier stored in the current context.

    Returns:
        The stored identifier, or ``None`` when nothing has been set
        (``None`` is the context variable's default).
    """
    return _request_id_var.get()
1 change: 0 additions & 1 deletion backend/extensions/trace/rag_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ async def wrapper(self, *args, **kwargs):
"score_threshold": retrieval_setting.similarity_threshold if retrieval_setting else None,
}
span.set_attribute(INPUT_VALUE, json.dumps(input_data, ensure_ascii=False))

span.set_attribute(GEN_AI_OPERATION_NAME, RETRIEVER_OPERATION_NAME)

results: List[SearchResult] = await func(self, *args, **kwargs)
Expand Down
15 changes: 15 additions & 0 deletions backend/extensions/trace/trace_context_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
from starlette.middleware.base import BaseHTTPMiddleware
from opentelemetry.propagate import get_global_textmap
from opentelemetry import context
from opentelemetry.baggage import get_baggage
import uuid


from extensions.trace.context import (
AGENTSCOPE_REQUEST_ID_KEY,
set_request_id,
)


class TraceContextMiddleware(BaseHTTPMiddleware):
"""
Expand All @@ -19,7 +28,13 @@ async def dispatch(self, request: Request, call_next):
# 在提取的 context 中执行
token = context.attach(extracted_context)
try:
# 从 baggage 中获取 request_id 并设置到 context
request_id = get_baggage(AGENTSCOPE_REQUEST_ID_KEY) or uuid.uuid4().hex
set_request_id(request_id)

response = await call_next(request)
return response
finally:
# 清理 request_id
set_request_id(None)
context.detach(token)
4 changes: 2 additions & 2 deletions backend/utils/format_logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import sys
from extensions.trace.context import get_request_id
from loguru import logger
from asgi_correlation_id.context import correlation_id


class InterceptHandler(logging.Handler):
Expand All @@ -28,7 +28,7 @@ def emit(self, record: logging.LogRecord) -> None:

# 自定义日志格式,加入 request_id
def formatter(record):
record["extra"]["request_id"] = correlation_id.get()
record["extra"]["request_id"] = get_request_id()
if record["extra"].get("request_id", None):
return (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
Expand Down
Loading