Merged
Changes from 2 commits
2 changes: 2 additions & 0 deletions backend/app/main.py
@@ -81,6 +81,7 @@ def create_app():
from app.log_middleware import CustomLoggingMiddleware
from fastapi.middleware.cors import CORSMiddleware
from api.request_validate_exception import validation_exception_handler
from extensions.trace.base import setup_propagator

app = FastAPI(lifespan=lifespan)
add_config_router(app)
@@ -100,6 +101,7 @@ def create_app():
app.add_middleware(CustomLoggingMiddleware)
app.add_exception_handler(ApiException, api_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
setup_propagator(app)
return app

app = create_app()
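
A minimal sketch of how the propagation wiring above could be exercised, assuming FastAPI's TestClient and a hypothetical /health route; the W3C header values are illustrative:

from fastapi.testclient import TestClient
from app.main import create_app

client = TestClient(create_app())

# With the composite propagator installed and FastAPIInstrumentor active,
# the server span for this request joins the caller's trace, and the
# baggage entries become readable via opentelemetry.baggage on the server side.
headers = {
    "traceparent": "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01",
    "baggage": "tenant_id=demo",
}
response = client.get("/health", headers=headers)  # /health is a hypothetical route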
1 change: 1 addition & 0 deletions backend/common/tool/search_result.py
@@ -3,6 +3,7 @@


class SearchResult(BaseModel):
id: str | None = None
title: str | None = None
content: str | None = None
url: str | None = None
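
For illustration only (the remaining SearchResult fields are truncated above and assumed unchanged), the new optional id simply accompanies the other fields:

# Hypothetical construction; only fields visible in this diff are shown.
result = SearchResult(id="doc-123", title="Intro", content="...", url="https://example.com/doc")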
4 changes: 2 additions & 2 deletions backend/db/models/trace.py
@@ -17,6 +17,6 @@ class TraceModel(SQLModel):
class TraceModelEntity(TraceModel, table=True):
__tablename__ = "pai_trace_config"

id: str = Field(default=lambda x: str(uuid.uuid4().hex), primary_key=True)
id: str = Field(default_factory=lambda: str(uuid.uuid4().hex), primary_key=True)
def is_enabled(self) -> bool:
return self.enabled and self.service_name and self.token and self.endpoint
return bool(self.enabled and self.service_name and self.endpoint)
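
Note that default_factory is invoked with no arguments, so the factory must be a zero-argument callable; a quick sketch of the intended behavior on a plain (non-table) model:

import uuid
from sqlmodel import Field, SQLModel

class IdExample(SQLModel):  # hypothetical model, for illustration only
    # default_factory is called with no arguments when the field is omitted,
    # yielding a fresh 32-character hex id per instance.
    id: str = Field(default_factory=lambda: uuid.uuid4().hex, primary_key=True)

print(IdExample().id)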
14 changes: 13 additions & 1 deletion backend/extensions/trace/base.py
@@ -23,7 +23,19 @@
from extensions.trace.reloadable_exporter import ReloadableOTLPSpanExporter
from extensions.trace import context as trace_context
from extensions.trace.trace_config import TraceConfig

from opentelemetry.propagate import set_global_textmap
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from opentelemetry.baggage.propagation import W3CBaggagePropagator
from opentelemetry.propagators.composite import CompositePropagator
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor


def setup_propagator(app):
set_global_textmap(CompositePropagator([
TraceContextTextMapPropagator(),  # handles traceparent
W3CBaggagePropagator()  # handles baggage
]))
FastAPIInstrumentor.instrument_app(app)

# trace_provider is a singleton and does not support being replaced, so when the trace config changes we overwrite the exporter and resource by default
# This way, even if the user entered a wrong password, a refresh can still succeed
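
A minimal sketch of what the global composite propagator provides for outbound calls, assuming setup_propagator has already run and an SDK tracer provider is configured (as this trace extension does elsewhere):

from opentelemetry import baggage, context, trace
from opentelemetry.propagate import inject

tracer = trace.get_tracer(__name__)

with tracer.start_as_current_span("upstream-work"):
    # Attach an illustrative baggage entry to the current context.
    token = context.attach(baggage.set_baggage("tenant_id", "demo"))
    try:
        carrier: dict[str, str] = {}
        inject(carrier)  # fills "traceparent" and "baggage" via the global propagator
        print(carrier)   # headers a downstream HTTP client could forward
    finally:
        context.detach(token)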
1 change: 0 additions & 1 deletion backend/extensions/trace/pai_agent_wrapper.py
@@ -8,7 +8,6 @@
from opentelemetry.trace import set_span_in_context
from opentelemetry.trace.status import Status, StatusCode
from openinference.semconv.trace import SpanAttributes, OpenInferenceSpanKindValues

from extensions.trace import context as trace_context
from extensions.trace.utils import pydantic_to_dict

319 changes: 319 additions & 0 deletions backend/extensions/trace/rag_wrapper.py
@@ -0,0 +1,319 @@
from functools import wraps
import json
import os
from typing import List
from enum import Enum
from common.tool.search_result import SearchResult
from opentelemetry.trace.status import Status, StatusCode
from openinference.semconv.trace import SpanAttributes, OpenInferenceSpanKindValues
from opentelemetry import context

from extensions.trace.utils import pydantic_to_dict

from extensions.trace.tracer import get_tracer

GEN_AI_SPAN_KIND = "gen_ai.span.kind"
GEN_AI_OPERATION_NAME = "gen_ai.operation.name"
INPUT_MESSAGES = "gen_ai.input.messages"

INPUT_VALUE = SpanAttributes.INPUT_VALUE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE

RETRIEVER_SPAN_KIND = OpenInferenceSpanKindValues.RETRIEVER.value
EMBEDDING_SPAN_KIND = OpenInferenceSpanKindValues.EMBEDDING.value
RERANKER_SPAN_KIND = OpenInferenceSpanKindValues.RERANKER.value

RETRIEVER_OPERATION_NAME = "retrieve_documents"
EMBEDDING_OPERATION_NAME = "embedding"
RERANKER_OPERATION_NAME = "rerank_documents"

EMBEDDING_MODEL_NAME = "gen_ai.request.model"
EMBEDDING_DIMENSION_COUNT = "gen_ai.embeddings.dimension.count"

RERANKER_MODEL_NAME = "gen_ai.request.model"

# whether to disable legacy trace and only use agentscope data contract
DISABLE_LEGACY_TRACE = os.getenv("DISABLE_LEGACY_TRACE", "false").lower() in ["true", "1", "yes", "y"]


class RetrieverSpanNames(str, Enum):
KNOWLEDGE_RETRIEVER = "KnowledgeRetriever"
TEXT_RETRIEVER = "TextRetriever"
VECTOR_RETRIEVER = "VectorRetriever"
EMBEDDING = "Embedding"
RERANKER = "Reranker"


STATUS_OK = Status(StatusCode.OK)


def query_knowledgebase_wrapper(func):
"""decorator to capture input & output string of query knowledgebase operation."""

@wraps(func)
async def wrapper(self, *args, **kwargs):
# if not enabled, directly return
if os.getenv("TRACING_ENABLED", "false") != "true":
return await func(self, *args, **kwargs)

query = kwargs.get("query", "[unknown]")
messages = [{
"role": "user",
"content": query,
"metadata": kwargs,
}]
ctx = context.get_current()
with get_tracer().start_as_current_span(RetrieverSpanNames.KNOWLEDGE_RETRIEVER, context=ctx) as span:
try:
span.set_attribute(GEN_AI_SPAN_KIND, RETRIEVER_SPAN_KIND)
span.set_attribute(INPUT_MESSAGES, json.dumps(pydantic_to_dict(messages), ensure_ascii=False))
retrieval_setting = kwargs.get("retrieval_setting", None)
input_data = {
"query": query,
"top_k": retrieval_setting.top_k if retrieval_setting else None,
"score_threshold": retrieval_setting.similarity_threshold if retrieval_setting else None,
}
span.set_attribute(INPUT_VALUE, json.dumps(input_data, ensure_ascii=False))

span.set_attribute(GEN_AI_OPERATION_NAME, RETRIEVER_OPERATION_NAME)

results: List[SearchResult] = await func(self, *args, **kwargs)
output_documents = [
{
"id": doc.id,
"content": doc.content,
"score": doc.score,
"metadata": doc.metadata,
}
for doc in results
]
output_data = {
"documents": output_documents,
"document_size": len(output_documents),
}
output_value = json.dumps(pydantic_to_dict(output_data), ensure_ascii=False)
span.set_attribute(OUTPUT_VALUE, output_value)
span.set_status(STATUS_OK)
return results
except Exception as e:
span.record_exception(e)
span.set_status(Status(StatusCode.ERROR, str(e)))
raise

return wrapper


def text_search_wrapper(func):
"""decorator to capture input & output string of query vector store operation."""

@wraps(func)
async def wrapper(*args, **kwargs):
# if not enabled, directly return
if os.getenv("TRACING_ENABLED", "false") != "true":
return await func(*args, **kwargs)

query = kwargs.get("query", "[unknown]")
messages = [{
"role": "user",
"content": query,
}]
with get_tracer().start_as_current_span(RetrieverSpanNames.TEXT_RETRIEVER) as span:
try:
span.set_attribute(GEN_AI_SPAN_KIND, RETRIEVER_SPAN_KIND)
span.set_attribute(INPUT_MESSAGES, json.dumps(pydantic_to_dict(messages), ensure_ascii=False))
input_data = {
"query": query,
"top_k": kwargs.get("top_k", None),
}
span.set_attribute(INPUT_VALUE, json.dumps(input_data, ensure_ascii=False))

span.set_attribute(GEN_AI_OPERATION_NAME, RETRIEVER_OPERATION_NAME)

text_search_result = await func(*args, **kwargs)
output_documents = []
if text_search_result and text_search_result.nodes:
output_documents = [
{
"id": text_search_result.ids[i],
"content": text_search_result.nodes[i].text,
"score": text_search_result.similarities[i],
"metadata": text_search_result.nodes[i].metadata,
}
for i in range(len(text_search_result.nodes))
]
output_data = {
"documents": output_documents,
"document_size": len(output_documents),
}
output_value = json.dumps(pydantic_to_dict(output_data), ensure_ascii=False)
span.set_attribute(OUTPUT_VALUE, output_value)
span.set_status(STATUS_OK)
return text_search_result
except Exception as e:
span.record_exception(e)
span.set_status(Status(StatusCode.ERROR, str(e)))
raise

return wrapper


def vector_search_wrapper(func):
"""decorator to capture input & output string of query vector store operation."""

@wraps(func)
async def wrapper(*args, **kwargs):
# if not enabled, directly return
if os.getenv("TRACING_ENABLED", "false") != "true":
return await func(*args, **kwargs)

query = kwargs.get("query", "[unknown]")
messages = [{
"role": "user",
"content": query,
}]
with get_tracer().start_as_current_span(RetrieverSpanNames.VECTOR_RETRIEVER) as span:
try:
span.set_attribute(GEN_AI_SPAN_KIND, RETRIEVER_SPAN_KIND)
span.set_attribute(INPUT_MESSAGES, json.dumps(pydantic_to_dict(messages), ensure_ascii=False))
input_data = {
"query": query,
"top_k": kwargs.get("top_k", None),
}
span.set_attribute(INPUT_VALUE, json.dumps(input_data, ensure_ascii=False))
span.set_attribute(GEN_AI_OPERATION_NAME, RETRIEVER_OPERATION_NAME)

vector_search_result = await func(*args, **kwargs)
output_documents = []
if vector_search_result and vector_search_result.nodes:
output_documents = [
{
"id": vector_search_result.ids[i],
"content": vector_search_result.nodes[i].text,
"score": vector_search_result.similarities[i],
"metadata": vector_search_result.nodes[i].metadata,
}
for i in range(len(vector_search_result.nodes))
]
output_data = {
"documents": output_documents,
"document_size": len(output_documents),
}
output_value = json.dumps(pydantic_to_dict(output_data), ensure_ascii=False)
span.set_attribute(OUTPUT_VALUE, output_value)
span.set_status(STATUS_OK)
return vector_search_result
except Exception as e:
span.record_exception(e)
span.set_status(Status(StatusCode.ERROR, str(e)))
raise

return wrapper


def embedding_wrapper(func):
"""decorator to capture input & output string of query knowledgebase operation."""

@wraps(func)
async def wrapper(self, *args, **kwargs):
# if not enabled, directly return
if os.getenv("TRACING_ENABLED", "false") != "true":
return await func(self, *args, **kwargs)

query = kwargs.get("query", "[unknown]")
embedding_model_entity = kwargs.get("embedding_model_entity", None)
messages = [{
"role": "user",
"content": query,
}]
try:
span = get_tracer().start_span(RetrieverSpanNames.EMBEDDING)
span.set_attribute(GEN_AI_SPAN_KIND, EMBEDDING_SPAN_KIND)
span.set_attribute(INPUT_MESSAGES, json.dumps(pydantic_to_dict(messages), ensure_ascii=False))
span.set_attribute(EMBEDDING_MODEL_NAME, embedding_model_entity.model_name)
span.set_attribute(INPUT_VALUE, query)

span.set_attribute(GEN_AI_OPERATION_NAME, EMBEDDING_OPERATION_NAME)

query_embedding = await func(self, *args, **kwargs)
span.set_attribute(EMBEDDING_DIMENSION_COUNT, len(query_embedding))
span.set_status(STATUS_OK)
span.end()
return query_embedding
except Exception as e:
span.record_exception(e)
span.set_status(Status(StatusCode.ERROR, str(e)))
span.end()
raise

return wrapper


def reranker_wrapper(func):
"""decorator to capture input & output string of reranker operation."""

@wraps(func)
async def wrapper(self, *args, **kwargs):
# if not enabled, directly return
if os.getenv("TRACING_ENABLED", "false") != "true":
return await func(self, *args, **kwargs)

query = kwargs.get("query", "[unknown]")
vector_result = kwargs.get("vector_result", None)
top_n = kwargs.get("top_n", None)
try:
span = get_tracer().start_span(RetrieverSpanNames.RERANKER)
span.set_attribute(GEN_AI_SPAN_KIND, RERANKER_SPAN_KIND)
span.set_attribute(RERANKER_MODEL_NAME, self.model)
messages = [{
"role": "user",
"content": query,
}]
span.set_attribute(INPUT_MESSAGES, json.dumps(pydantic_to_dict(messages), ensure_ascii=False))

if vector_result and vector_result.nodes:
input_documents = [
{
"id": vector_result.ids[i],
"content": vector_result.nodes[i].text,
"score": vector_result.similarities[i],
"metadata": vector_result.nodes[i].metadata,
}
for i in range(len(vector_result.nodes))
]
else:
input_documents = []

input_value = {
"documents": input_documents,
"query": query,
"document_size": len(input_documents),
"top_k": top_n,
}
span.set_attribute(INPUT_VALUE, json.dumps(pydantic_to_dict(input_value), ensure_ascii=False))

rerank_result = await func(self, *args, **kwargs)
output_documents = []
if rerank_result and rerank_result.nodes:
output_documents = [
{
"id": rerank_result.ids[i],
"content": rerank_result.nodes[i].text,
"score": rerank_result.similarities[i],
"metadata": rerank_result.nodes[i].metadata,
}
for i in range(len(rerank_result.nodes))
]
output_value = {
"documents": output_documents
}
span.set_attribute(OUTPUT_VALUE, json.dumps(pydantic_to_dict(output_value), ensure_ascii=False))
span.set_status(STATUS_OK)
span.end()
return rerank_result
except Exception as e:
span.record_exception(e)
span.set_status(Status(StatusCode.ERROR, str(e)))
span.end()
raise

return wrapper
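
A hedged example of how one of these decorators might be applied; the class and keyword-only method below are hypothetical stand-ins for the real knowledge-base retriever, and TRACING_ENABLED must be "true" for a span to be recorded:

from typing import List

from common.tool.search_result import SearchResult
from extensions.trace.rag_wrapper import query_knowledgebase_wrapper


class DemoKnowledgeBase:  # hypothetical class, not part of this PR
    @query_knowledgebase_wrapper
    async def query_knowledgebase(self, *, query: str, retrieval_setting=None) -> List[SearchResult]:
        # The wrapper records the query, retrieval settings, and the returned
        # documents (id, content, score, metadata) as span attributes around this call.
        return [SearchResult(id="1", content="example passage", url="https://example.com")]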