diff --git a/src/uipath_llamaindex/_cli/_tracing/_attribute_normalizer.py b/src/uipath_llamaindex/_cli/_tracing/_attribute_normalizer.py
new file mode 100644
index 0000000..9638265
--- /dev/null
+++ b/src/uipath_llamaindex/_cli/_tracing/_attribute_normalizer.py
@@ -0,0 +1,84 @@
+"""OpenTelemetry SpanProcessor for normalizing LlamaIndex tool call attributes.
+
+LlamaIndex wraps tool arguments in {"kwargs": {...}} which differs from other
+frameworks like LangChain that use flat {"arg": value} format. This processor
+normalizes the format at the span level before exporters or dev terminal read it.
+"""
+
+import json
+import logging
+from typing import Any, Optional
+
+from opentelemetry.context import Context
+from opentelemetry.sdk.trace import ReadableSpan, Span, SpanProcessor
+
+logger = logging.getLogger(__name__)
+
+
+class AttributeNormalizingSpanProcessor(SpanProcessor):
+    """Normalizes LlamaIndex tool call attributes to match other frameworks.
+
+    Unwraps {"kwargs": {...}} to flat {...} format for consistency with LangChain.
+    """
+
+    def on_start(self, span: Span, parent_context: Optional[Context] = None) -> None:
+        """Called when span starts - no action needed."""
+        pass
+
+    def on_end(self, span: ReadableSpan) -> None:
+        """Normalize tool call attributes before span is consumed by exporters/terminal."""
+        if not span._attributes:
+            return
+
+        try:
+            # Get the mutable internal attributes dict (private SDK API; any
+            # incompatibility is absorbed by the except below)
+            attrs: dict = span._attributes  # type: ignore[attr-defined]
+
+            if attrs.get("openinference.span.kind", None) == "TOOL":
+                # Normalize tool call attributes
+                for key in ("input.value", "output.value"):
+                    if key in attrs:
+                        original = attrs[key]
+                        normalized = self._normalize_attribute(key, original)
+
+                        if normalized != original:
+                            attrs[key] = normalized
+                            if logger.isEnabledFor(logging.DEBUG):
+                                # str() before slicing: the original attribute
+                                # value is not guaranteed to be a string
+                                logger.debug(
+                                    f"Normalized {key} in span '{span.name}': "
+                                    f"{str(original)[:50]}... → {normalized[:50]}..."
+                                )
+
+        except Exception as e:
+            # Don't crash span processing if normalization fails
+            logger.debug(
+                f"Failed to normalize span '{getattr(span, 'name', 'unknown')}': {e}"
+            )
+
+    def _normalize_attribute(self, key: str, value: Any) -> str:
+        """Unwrap LlamaIndex's kwargs wrapper if present; dicts are re-emitted as JSON."""
+        if isinstance(value, str):
+            try:
+                value = json.loads(value)
+            except Exception:
+                pass
+        if isinstance(value, dict):
+            if key == "input.value" and "kwargs" in value:
+                value = value["kwargs"]
+            elif key == "output.value" and "raw_output" in value:
+                # Only reshape outputs that actually carry the LlamaIndex
+                # fields; otherwise leave the payload intact below
+                value = {
+                    "content": value.get("raw_output"),
+                    "status": "error" if value.get("is_error", False) else "success",
+                    "tool_call_id": value.get("tool_call_id"),
+                }
+            # json.dumps (not str): str(dict) yields Python repr with single
+            # quotes, which is not valid JSON for downstream consumers
+            return json.dumps(value)
+        return str(value)
+
+    def shutdown(self) -> None:
+        """Called on processor shutdown - no cleanup needed."""
+        pass
+
+    def force_flush(self, timeout_millis: int = 30000) -> bool:
+        """Force flush - always succeeds (nothing to flush)."""
+        return True
diff --git a/src/uipath_llamaindex/_cli/cli_dev.py b/src/uipath_llamaindex/_cli/cli_dev.py
index 2d5deda..dc695d9 100644
--- a/src/uipath_llamaindex/_cli/cli_dev.py
+++ b/src/uipath_llamaindex/_cli/cli_dev.py
@@ -12,6 +12,7 @@
 
 from ._runtime._context import UiPathLlamaIndexRuntimeContext
 from ._runtime._runtime import UiPathLlamaIndexRuntime
+from ._tracing._attribute_normalizer import AttributeNormalizingSpanProcessor
 
 console = ConsoleLogger()
 
@@ -24,6 +25,9 @@ def llamaindex_dev_middleware(interface: Optional[str]) -> MiddlewareResult:
     runtime_factory = UiPathRuntimeFactory(
         UiPathLlamaIndexRuntime, UiPathLlamaIndexRuntimeContext
     )
+    runtime_factory.tracer_provider.add_span_processor(
+        AttributeNormalizingSpanProcessor()
+    )
     runtime_factory.add_instrumentor(LlamaIndexInstrumentor, get_current_span)
     app = UiPathDevTerminal(runtime_factory)
     asyncio.run(app.run_async())
diff --git a/src/uipath_llamaindex/_cli/cli_run.py b/src/uipath_llamaindex/_cli/cli_run.py
index 3a40935..e525038 100644
--- a/src/uipath_llamaindex/_cli/cli_run.py
+++ b/src/uipath_llamaindex/_cli/cli_run.py
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from os import environ as env
-from typing import Optional
+from typing import Any, Optional
 
 from openinference.instrumentation.llama_index import (
     LlamaIndexInstrumentor,
@@ -13,6 +13,7 @@
 from ._runtime._context import UiPathLlamaIndexRuntimeContext
 from ._runtime._exception import UiPathLlamaIndexRuntimeError
 from ._runtime._runtime import UiPathLlamaIndexRuntime
+from ._tracing._attribute_normalizer import AttributeNormalizingSpanProcessor
 from ._tracing._oteladapter import LlamaIndexExporter
 from ._utils._config import LlamaIndexConfig
 
@@ -20,7 +21,7 @@
 
 
 def llamaindex_run_middleware(
-    entrypoint: Optional[str], input: Optional[str], resume: bool, **kwargs
+    entrypoint: Optional[str], input: Optional[str], resume: bool, **kwargs: Any
 ) -> MiddlewareResult:
     """Middleware to handle LlamaIndex agent execution"""
 
@@ -67,6 +68,9 @@ async def execute():
             runtime_factory = UiPathRuntimeFactory(
                 UiPathLlamaIndexRuntime, UiPathLlamaIndexRuntimeContext
             )
 
+            runtime_factory.tracer_provider.add_span_processor(
+                AttributeNormalizingSpanProcessor()
+            )
             if context.job_id:
                 runtime_factory.add_span_exporter(LlamaIndexExporter())