24 changes: 15 additions & 9 deletions backend/api/v1/mcp/kb_retriever_tool.py
@@ -18,12 +18,16 @@ class NodeMetadata(BaseModel):
     title: str
     doc_name: str

-class NodeResult(BaseModel):
-    score: float
+class NodeInfo(BaseModel):
     metadata: NodeMetadata
     text: str = Field(description="The text content of the node")


+class NodeResult(BaseModel):
+    score: float
+    node: NodeInfo
+
+
 class RetrievalToolResponse(BaseModel):
     status: str
     status_code: int
@@ -80,13 +84,15 @@ async def asearch_knowledgebase(
         images = score_node.images or []
         node = NodeResult(
             score=score_node.score,
-            metadata=NodeMetadata(
-                file_path=file_path or "",
-                image_url=[img.get("url", "") for img in images if isinstance(img, dict) and img.get("url", "")],
-                title=title or "",
-                doc_name=doc_name or "",
-            ),
-            text=score_node.content,
+            node=NodeInfo(
+                text=score_node.content,
+                metadata=NodeMetadata(
+                    file_path=file_path or "",
+                    image_url=[img.get("url", "") for img in images if isinstance(img, dict) and img.get("url", "")],
+                    title=title or "",
+                    doc_name=doc_name or "",
+                ),
+            )
         )
         nodes.append(node)

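With this change a retrieval hit no longer carries metadata and text at the top level of NodeResult; they move into the nested NodeInfo, so each hit serializes as {score, node: {text, metadata}}. A small sketch of the reshaped result, using the models defined above (field values are illustrative, not taken from the PR):

# Illustrative only: building one hit by hand with the new nesting.
hit = NodeResult(
    score=0.87,
    node=NodeInfo(
        text="Example chunk text ...",
        metadata=NodeMetadata(
            file_path="kb/docs/example.md",  # hypothetical values
            image_url=[],
            title="Example",
            doc_name="example.md",
        ),
    ),
)
# Serialized, the hit now nests text/metadata under "node":
# {"score": 0.87, "node": {"text": "...", "metadata": {...}}}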
6 changes: 4 additions & 2 deletions backend/app/main.py
@@ -1,9 +1,11 @@
 # init trace
 from dotenv import load_dotenv
+import os
+
+# don't trace for elasticsearch client
+os.environ["OTEL_PYTHON_INSTRUMENTATION_ELASTICSEARCH_ENABLED"] = "false"
 load_dotenv()
-
-import os
 import asyncio
 from fastapi import FastAPI
 import threading
 # setup models
2 changes: 1 addition & 1 deletion backend/extensions/trace/grpc_exporter.py
@@ -70,7 +70,7 @@ def __init__(
             Union[TypingSequence[Tuple[str, str]], Dict[str, str], str]
         ] = None,
         timeout: Optional[int] = None,
-        compression: Optional[Compression] = None,
+        compression: Optional[Compression] = Compression.Gzip,
     ):
         if insecure is None:
             insecure = environ.get(OTEL_EXPORTER_OTLP_TRACES_INSECURE)
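Before this change the wrapper left compression as None, which for the stock OTLP exporters means falling back to the OTEL_EXPORTER_OTLP_TRACES_COMPRESSION / OTEL_EXPORTER_OTLP_COMPRESSION environment variables and, if those are unset, to no compression; gzip is now the hard-coded default. For comparison, the environment-variable route (a sketch, not part of this PR):

import os

# Only takes effect when `compression` is not passed explicitly to the exporter.
os.environ["OTEL_EXPORTER_OTLP_TRACES_COMPRESSION"] = "gzip"   # traces only
os.environ["OTEL_EXPORTER_OTLP_COMPRESSION"] = "gzip"          # all OTLP signals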
4 changes: 3 additions & 1 deletion backend/extensions/trace/http_exporter.py
@@ -20,6 +20,7 @@
     OTEL_EXPORTER_OTLP_TRACES_ENDPOINT,
     OTEL_EXPORTER_OTLP_TRACES_HEADERS,
 )
+from opentelemetry.exporter.otlp.proto.http import Compression
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.environment_variables import (
     OTEL_EXPORTER_OTLP_HEADERS,
@@ -38,7 +39,7 @@ class ReloadableHttpOTLPSpanExporter(OTLPSpanExporter):
         credentials: Credentials object for server authentication
         headers: Headers to send when exporting
         timeout: Backend request timeout in seconds
-        compression: gRPC compression method to use
+        compression: compression method to use
     """
     def __init__(
         self,
@@ -62,6 +63,7 @@ def __init__(
             endpoint=endpoint,
             headers=headers,
             timeout=timeout,
+            compression=Compression.Gzip,
         )

     def reload(
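The HTTP exporter wrapper now pins gzip as well, via the Compression enum imported above. A minimal sketch of the same setting on the stock HTTP span exporter (the endpoint is the conventional local collector default, assumed here, not taken from the PR):

from opentelemetry.exporter.otlp.proto.http import Compression
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

# Stand-alone equivalent of the gzip setting the wrapper now passes to the base exporter.
exporter = OTLPSpanExporter(
    endpoint="http://localhost:4318/v1/traces",
    compression=Compression.Gzip,  # request bodies are gzip-compressed
)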
7 changes: 6 additions & 1 deletion backend/extensions/trace/trace_context_middleware.py
@@ -6,13 +6,16 @@
 from opentelemetry import context
 from opentelemetry.baggage import get_baggage
 import uuid

+import os
+from loguru import logger

 from extensions.trace.context import (
     AGENTSCOPE_REQUEST_ID_KEY,
     set_request_id,
 )

+ENABLE_TRACE_CONTEXT_DEBUG = os.getenv("ENABLE_TRACE_CONTEXT_DEBUG", "false").lower() in ["true", "1", "yes", "y"]
+
-
 class TraceContextMiddleware(BaseHTTPMiddleware):
     """
@@ -22,6 +25,8 @@ class TraceContextMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         # extract the incoming trace context
         carrier = dict(request.headers)
+        if ENABLE_TRACE_CONTEXT_DEBUG:
+            logger.info(f"Trace context debug headers: {carrier}")
         propagator = get_global_textmap()
         extracted_context = propagator.extract(carrier=carrier)

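The middleware reads the incoming W3C trace headers through the globally configured propagator; the new flag only adds logging of the raw header carrier when ENABLE_TRACE_CONTEXT_DEBUG is truthy. A minimal sketch of that extraction step in isolation (the traceparent value is a made-up example):

from opentelemetry.propagate import get_global_textmap

# Hypothetical inbound headers, as an upstream caller would send them.
carrier = {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}

# Same call the middleware makes: turn the headers into an OpenTelemetry Context.
extracted_context = get_global_textmap().extract(carrier=carrier)
# The returned context can then be attached (opentelemetry.context.attach)
# so spans created while handling the request join the caller's trace.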
7 changes: 5 additions & 2 deletions backend/service/knowledgebase/rag_service.py
@@ -40,6 +40,7 @@
 from db.db_context import create_db_session
 from common.knowledgebase.constants import DEFAULT_VECTOR_WEIGHT, DEFAULT_SIMILARITY_TOP_K, DEFAULT_RERANK_SIMILARITY_TOP_K
 from extensions.trace.rag_wrapper import query_knowledgebase_wrapper, embedding_wrapper
+from openinference.instrumentation import suppress_tracing
 from loguru import logger

 MARKDOWN_IMAGE_PATTERN = r'!\[.*?\]\((.*?)\)\s*\n*\s*图片的描述:\s*(.*?)(?=\n\n|$)'
@@ -832,8 +833,10 @@ async def format_search_result(self, reranked_result: VectorStoreQueryResult, te
     @embedding_wrapper
     async def embed_query(self, query: str, embedding_model_entity: EmbeddingModelEntity) -> List[float]:
         embed_model = create_embedding_model(embedding_model_entity)
-        query_embedding = await embed_model.aget_query_embedding(query)
-        return query_embedding
+        # skip tracing for embedding vectors
+        with suppress_tracing():
+            query_embedding = await embed_model.aget_query_embedding(query)
+        return query_embedding

     # when launching concurrent SessionScope queries, each query needs its own session instance
     async def _aquery_task(
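suppress_tracing, imported above from openinference.instrumentation, is a context manager that keeps instrumentation from recording spans for whatever runs inside the block; per the comment, the point here is keeping raw embedding vectors out of exported trace data while the surrounding embedding_wrapper span remains. A minimal usage sketch with a stand-in embedding model (not the project's classes):

from openinference.instrumentation import suppress_tracing

async def embed_without_tracing(embed_model, query: str):
    # Spans produced inside this block are suppressed, so large embedding
    # payloads never reach the span exporter.
    with suppress_tracing():
        return await embed_model.aget_query_embedding(query)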
19 changes: 13 additions & 6 deletions backend/tools/utils/vectordb_retrieval.py
@@ -34,7 +34,14 @@ async def _aquery_text(
         query_kwargs["filters"] = metadata_filters
     elif document_ids:
         query_kwargs["doc_ids"] = document_ids
-    return await vector_store.aquery(VectorStoreQuery(**query_kwargs))
+    text_result = await vector_store.aquery(VectorStoreQuery(**query_kwargs))
+    # score normalization for TEXT_SEARCH mode
+    if text_result and text_result.similarities:
+        text_result.similarities = min_max_normalize_scores(text_result.similarities)
+    if text_result and text_result.nodes:
+        for node in text_result.nodes:
+            node.metadata.pop("page_bbox", None)
+    return text_result


 @vector_search_wrapper
@@ -56,7 +63,11 @@ async def _aquery_vector(
         query_kwargs["filters"] = metadata_filters
     elif document_ids:
         query_kwargs["doc_ids"] = document_ids
-    return await vector_store.aquery(VectorStoreQuery(**query_kwargs))
+    dense_result = await vector_store.aquery(VectorStoreQuery(**query_kwargs))
+    if dense_result and dense_result.nodes:
+        for node in dense_result.nodes:
+            node.metadata.pop("page_bbox", None)
+    return dense_result


 async def aquery_vector_store(
@@ -133,10 +144,6 @@ async def aquery_vector_store(
                 metadata_filters=metadata_filters,
             )
             logger.info(f"{query_mode} mode: Retrieved {len(dense_result.nodes)} nodes.")
-
-            # score normalization for TEXT_SEARCH mode
-            if text_result and text_result.similarities:
-                text_result.similarities = min_max_normalize_scores(text_result.similarities)
     except Exception as e:
         logger.error(f"Failed to query vector store: {e}")
         raise
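min_max_normalize_scores itself is not shown in this diff; a typical min-max rescaling, mapping the raw text-search (e.g. BM25) scores onto [0, 1] so they can be combined with vector similarities, might look like this hypothetical helper:

from typing import List

def min_max_normalize(scores: List[float]) -> List[float]:
    # Rescale to [0, 1]; a constant score list maps to 1.0 so a single or
    # tied hit is not zeroed out (an assumption, not the project's rule).
    lo, hi = min(scores), max(scores)
    if hi == lo:
        return [1.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]

# Example: min_max_normalize([12.4, 7.1, 3.3]) -> [1.0, 0.417..., 0.0]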