Merged
Changes from 2 commits
26 changes: 17 additions & 9 deletions backend/api/v1/mcp/kb_retriever_tool.py
@@ -16,14 +16,19 @@ class NodeMetadata(BaseModel):
file_path: str
image_url: List[str]
title: str
content: str
doc_name: str

class NodeResult(BaseModel):
score: float
class NodeInfo(BaseModel):
metadata: NodeMetadata
text: str = Field(description="The text content of the node")


class NodeResult(BaseModel):
score: float
node: NodeInfo


class RetrievalToolResponse(BaseModel):
status: str
status_code: int
@@ -80,13 +85,16 @@ async def asearch_knowledgebase(
images = score_node.images or []
node = NodeResult(
score=score_node.score,
metadata=NodeMetadata(
file_path=file_path or "",
image_url=[img.get("url", "") for img in images if isinstance(img, dict) and img.get("url", "")],
title=title or "",
doc_name=doc_name or "",
),
text=score_node.content,
node=NodeInfo(
text=score_node.content,
metadata=NodeMetadata(
file_path=file_path or "",
image_url=[img.get("url", "") for img in images if isinstance(img, dict) and img.get("url", "")],
title=title or "",
doc_name=doc_name or "",
content=score_node.content,
),
)
)
nodes.append(node)

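Note: the tool response now nests text and metadata under a `node` key instead of flattening them onto the result. A minimal sketch of the resulting payload shape, using the models above with illustrative values only (`model_dump()` assumes pydantic v2; on v1 it would be `.dict()`):

```python
from typing import List
from pydantic import BaseModel, Field


class NodeMetadata(BaseModel):
    file_path: str
    image_url: List[str]
    title: str
    content: str
    doc_name: str


class NodeInfo(BaseModel):
    metadata: NodeMetadata
    text: str = Field(description="The text content of the node")


class NodeResult(BaseModel):
    score: float
    node: NodeInfo


# Illustrative values only; real results come from the knowledge base search.
result = NodeResult(
    score=0.87,
    node=NodeInfo(
        text="Chunk text ...",
        metadata=NodeMetadata(
            file_path="docs/guide.md",
            image_url=[],
            title="Guide",
            doc_name="guide.md",
            content="Chunk text ...",
        ),
    ),
)
print(result.model_dump())
# {'score': 0.87, 'node': {'metadata': {...}, 'text': 'Chunk text ...'}}
```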
6 changes: 4 additions & 2 deletions backend/app/main.py
@@ -1,9 +1,11 @@
# init trace
from dotenv import load_dotenv
import os

# don't trace the elasticsearch client
os.environ["OTEL_PYTHON_INSTRUMENTATION_ELASTICSEARCH_ENABLED"] = "false"
load_dotenv()

import os
import asyncio
from fastapi import FastAPI
import threading
# setup models
2 changes: 1 addition & 1 deletion backend/extensions/trace/grpc_exporter.py
@@ -70,7 +70,7 @@ def __init__(
Union[TypingSequence[Tuple[str, str]], Dict[str, str], str]
] = None,
timeout: Optional[int] = None,
compression: Optional[Compression] = None,
compression: Optional[Compression] = Compression.Gzip,
):
if insecure is None:
insecure = environ.get(OTEL_EXPORTER_OTLP_TRACES_INSECURE)
4 changes: 3 additions & 1 deletion backend/extensions/trace/http_exporter.py
@@ -20,6 +20,7 @@
OTEL_EXPORTER_OTLP_TRACES_ENDPOINT,
OTEL_EXPORTER_OTLP_TRACES_HEADERS,
)
from opentelemetry.exporter.otlp.proto.http import Compression
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.environment_variables import (
OTEL_EXPORTER_OTLP_HEADERS,
@@ -38,7 +39,7 @@ class ReloadableHttpOTLPSpanExporter(OTLPSpanExporter):
credentials: Credentials object for server authentication
headers: Headers to send when exporting
timeout: Backend request timeout in seconds
compression: gRPC compression method to use
compression: compression method to use
"""
def __init__(
self,
@@ -62,6 +63,7 @@ def __init__(
endpoint=endpoint,
headers=headers,
timeout=timeout,
compression=Compression.Gzip,
)

def reload(
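Note: both exporter subclasses now pass `Compression.Gzip` through to the parent OTLP exporter, so exported trace batches are gzip-compressed on the wire (smaller payloads at the cost of a little CPU). A minimal sketch against the stock HTTP exporter, with a placeholder endpoint:

```python
from opentelemetry.exporter.otlp.proto.http import Compression
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

# Placeholder endpoint; each export request body is gzip-compressed.
exporter = OTLPSpanExporter(
    endpoint="http://localhost:4318/v1/traces",
    compression=Compression.Gzip,
)
```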
7 changes: 6 additions & 1 deletion backend/extensions/trace/trace_context_middleware.py
@@ -6,13 +6,16 @@
from opentelemetry import context
from opentelemetry.baggage import get_baggage
import uuid

import os
from loguru import logger

from extensions.trace.context import (
AGENTSCOPE_REQUEST_ID_KEY,
set_request_id,
)

ENABLE_TRACE_CONTEXT_DEBUG = os.getenv("ENABLE_TRACE_CONTEXT_DEBUG", "false").lower() in ["true", "1", "yes", "y"]


class TraceContextMiddleware(BaseHTTPMiddleware):
"""
@@ -22,6 +25,8 @@ class TraceContextMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Extract trace context
carrier = dict(request.headers)
if ENABLE_TRACE_CONTEXT_DEBUG:
logger.info(f"Trace context debug headers: {carrier}")
propagator = get_global_textmap()
extracted_context = propagator.extract(carrier=carrier)

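Note: `ENABLE_TRACE_CONTEXT_DEBUG` is evaluated once at module import time, so it must be present in the environment (or the `.env` loaded at startup) before this middleware is imported; flipping it per request has no effect. A small illustration of the accepted truthy values:

```python
import os

# Must be set before the middleware module is imported.
os.environ["ENABLE_TRACE_CONTEXT_DEBUG"] = "yes"

enabled = os.getenv("ENABLE_TRACE_CONTEXT_DEBUG", "false").lower() in ["true", "1", "yes", "y"]
print(enabled)  # True
```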
1 change: 1 addition & 0 deletions backend/service/knowledgebase/rag_service.py
@@ -832,6 +832,7 @@ async def format_search_result(self, reranked_result: VectorStoreQueryResult, te
@embedding_wrapper
async def embed_query(self, query: str, embedding_model_entity: EmbeddingModelEntity) -> List[float]:
embed_model = create_embedding_model(embedding_model_entity)
# skip tracing for embedding vectors
query_embedding = await embed_model.aget_query_embedding(query)
return query_embedding

19 changes: 13 additions & 6 deletions backend/tools/utils/vectordb_retrieval.py
@@ -34,7 +34,14 @@ async def _aquery_text(
query_kwargs["filters"] = metadata_filters
elif document_ids:
query_kwargs["doc_ids"] = document_ids
return await vector_store.aquery(VectorStoreQuery(**query_kwargs))
text_result = await vector_store.aquery(VectorStoreQuery(**query_kwargs))
# Score normalization for TEXT_SEARCH mode
if text_result and text_result.similarities:
text_result.similarities = min_max_normalize_scores(text_result.similarities)
if text_result and text_result.nodes:
for node in text_result.nodes:
node.metadata.pop("page_bbox", None)
return text_result


@vector_search_wrapper
@@ -56,7 +63,11 @@ async def _aquery_vector(
query_kwargs["filters"] = metadata_filters
elif document_ids:
query_kwargs["doc_ids"] = document_ids
return await vector_store.aquery(VectorStoreQuery(**query_kwargs))
dense_result = await vector_store.aquery(VectorStoreQuery(**query_kwargs))
if dense_result and dense_result.nodes:
for node in dense_result.nodes:
node.metadata.pop("page_bbox", None)
return dense_result


async def aquery_vector_store(
@@ -133,10 +144,6 @@ async def aquery_vector_store(
metadata_filters=metadata_filters,
)
logger.info(f"{query_mode} mode: Retrieved {len(dense_result.nodes)} nodes.")

# Score normalization for TEXT_SEARCH mode
if text_result and text_result.similarities:
text_result.similarities = min_max_normalize_scores(text_result.similarities)
except Exception as e:
logger.error(f"Failed to query vector store: {e}")
raise
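
Note: min-max normalization now runs inside `_aquery_text` itself, so TEXT_SEARCH scores are already mapped to [0, 1] before they are merged with dense retrieval results. `min_max_normalize_scores` is an existing helper in the repo and is not shown in this diff; a hypothetical sketch of the standard formula it presumably implements:

```python
from typing import List


def min_max_normalize_scores(scores: List[float]) -> List[float]:
    """Hypothetical sketch only; the repo's helper may handle edge cases differently."""
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # Degenerate case: all scores identical.
        return [1.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]


print(min_max_normalize_scores([2.0, 5.0, 8.0]))  # [0.0, 0.5, 1.0]
```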