diff --git a/python/instrumentation/openinference-instrumentation-autogen-agentchat/src/openinference/instrumentation/autogen_agentchat/_wrappers.py b/python/instrumentation/openinference-instrumentation-autogen-agentchat/src/openinference/instrumentation/autogen_agentchat/_wrappers.py
index 6e14741ef7..a84b391957 100644
--- a/python/instrumentation/openinference-instrumentation-autogen-agentchat/src/openinference/instrumentation/autogen_agentchat/_wrappers.py
+++ b/python/instrumentation/openinference-instrumentation-autogen-agentchat/src/openinference/instrumentation/autogen_agentchat/_wrappers.py
@@ -36,6 +36,8 @@ from openinference.instrumentation import (
     get_attributes_from_context,
+    get_llm_model_name_attributes,
+    get_llm_token_count_attributes,
     get_output_attributes,
     safe_json_dumps,
 )
 from openinference.semconv.trace import (
@@ -48,6 +50,7 @@
     ToolAttributes,
     ToolCallAttributes,
 )
+from openinference.semconv.trace import SpanAttributes
 
 logger = logging.getLogger(__name__)
 logger.addHandler(logging.NullHandler())
@@ -97,6 +100,80 @@ def _flatten(mapping: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Attrib
         yield key, value
 
 
+def _get_token_value(obj: Any, attr_name: str) -> Optional[int]:
+    """Extract a token count from an attribute of ``obj``, if present."""
+    if hasattr(obj, attr_name):
+        value = getattr(obj, attr_name)
+        if value is not None:
+            return value
+    return None
+
+
+def _extract_details_from_object(details: Any, mapping: Dict[str, str]) -> Dict[str, int]:
+    """Extract token detail counts from a dict or from object attributes."""
+    result: Dict[str, int] = {}
+    if isinstance(details, dict):
+        for key, attr_name in mapping.items():
+            value = details.get(attr_name)
+            if value is not None:
+                result[key] = value
+    else:
+        for key, attr_name in mapping.items():
+            value = _get_token_value(details, attr_name)
+            if value is not None:
+                result[key] = value
+    return result
+
+
+def _extract_token_usage(result: Any) -> Dict[str, Any]:
+    """Extract token usage information from a CreateResult."""
+    if not hasattr(result, "usage") or not result.usage:
+        return {}
+
+    usage = result.usage
+    token_usage: Dict[str, Any] = {}
+
+    prompt_tokens = _get_token_value(usage, "prompt_tokens")
+    if prompt_tokens is not None:
+        token_usage["prompt"] = prompt_tokens
+
+    completion_tokens = _get_token_value(usage, "completion_tokens")
+    if completion_tokens is not None:
+        token_usage["completion"] = completion_tokens
+
+    details = None
+    if hasattr(usage, "prompt_tokens_details"):
+        details = usage.prompt_tokens_details
+    elif hasattr(usage, "__dict__") and "prompt_tokens_details" in usage.__dict__:
+        details = usage.__dict__["prompt_tokens_details"]
+
+    if details:
+        prompt_details = _extract_details_from_object(
+            details,
+            {"cache_read": "cached_tokens", "audio": "audio_tokens", "cache_input": "text_tokens"},
+        )
+        if prompt_details:
+            token_usage["prompt_details"] = prompt_details
+
+    details = None
+    if hasattr(usage, "completion_tokens_details"):
+        details = usage.completion_tokens_details
+    elif hasattr(usage, "__dict__") and "completion_tokens_details" in usage.__dict__:
+        details = usage.__dict__["completion_tokens_details"]
+
+    if details:
+        completion_details = _extract_details_from_object(
+            details,
+            {"reasoning": "reasoning_tokens", "audio": "audio_tokens"},
+        )
+        if completion_details:
+            token_usage["completion_details"] = completion_details
+
+    if "prompt" in token_usage and "completion" in token_usage:
+        token_usage["total"] = token_usage["prompt"] + token_usage["completion"]
+
+    return token_usage
+
+
 def _get_input_value(method: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
     """
     Parses a method call's inputs into a JSON string. Ensures a consistent
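A minimal sketch of the payload `_extract_token_usage` is meant to build from the hunk above. The `SimpleNamespace` stand-ins and the numbers are illustrative assumptions, not fixtures from this repository; real autogen `CreateResult.usage` objects may carry only `prompt_tokens`/`completion_tokens`:

```python
from types import SimpleNamespace

# Illustrative stand-in for a CreateResult whose .usage carries OpenAI-style
# token fields, including the optional *_tokens_details payloads.
result = SimpleNamespace(
    usage=SimpleNamespace(
        prompt_tokens=120,
        completion_tokens=48,
        prompt_tokens_details={"cached_tokens": 64, "audio_tokens": 0},
        completion_tokens_details={"reasoning_tokens": 32, "audio_tokens": 0},
    )
)

# Per the mappings above, _extract_token_usage(result) would return:
# {
#     "prompt": 120,
#     "completion": 48,
#     "prompt_details": {"cache_read": 64, "audio": 0},   # no "text_tokens" -> "cache_input" dropped
#     "completion_details": {"reasoning": 32, "audio": 0},
#     "total": 168,
# }
```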
@@ -466,6 +543,9 @@ async def __call__(
             if param_name in method_signature.parameters:
                 valid_kwargs[param_name] = param_value
+
+        # Extract model name from instance
+        model_name = getattr(instance, "model", None)
 
         with self._tracer.start_as_current_span(
             span_name,
             attributes=dict(
@@ -474,6 +554,7 @@
                         OPENINFERENCE_SPAN_KIND: LLM,
                         **dict(_llm_messages_attributes(messages, "input")),
                         **dict(_get_llm_tool_attributes(tools)),
+                        **dict(get_llm_model_name_attributes(model_name)),
                         **dict(get_attributes_from_context()),
                     }
                 )
@@ -488,13 +569,25 @@
             # Extract output attributes and process tool calls in response
             output_attributes = dict(get_output_attributes(result))
             tool_call_attributes = dict(_extract_output_tool_calls(result))
-
+
+            # Extract token usage information
+            token_usage = _extract_token_usage(result)
+            token_attributes = dict(get_llm_token_count_attributes(token_usage))
+
+            # Add completion_details if available (reasoning_tokens, audio_tokens)
+            if completion_details := token_usage.get("completion_details"):
+                if isinstance(completion_details, dict):
+                    if reasoning_tokens := completion_details.get("reasoning"):
+                        token_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING] = reasoning_tokens
+                    if audio_tokens := completion_details.get("audio"):
+                        token_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO] = audio_tokens
             span.set_attributes(
                 dict(
                     _flatten(
                         {
                             **output_attributes,
                             **tool_call_attributes,
+                            **token_attributes,
                         }
                     )
                 )
             )
@@ -535,6 +628,28 @@
         for param_name, param_value in kwargs.items():
             if param_name in method_signature.parameters:
                 valid_kwargs[param_name] = param_value
+
+        # Ensure stream_options includes include_usage for token metrics
+        if valid_kwargs.get("include_usage") is None:
+            extra_create_args = valid_kwargs.get("extra_create_args", {})
+            if isinstance(extra_create_args, dict) and "extra_create_args" in method_signature.parameters:
+                stream_options = extra_create_args.get("stream_options", {})
+                if not stream_options.get("include_usage"):
+                    # Inject stream_options to ensure we get usage data
+                    if "extra_create_args" not in valid_kwargs:
+                        valid_kwargs["extra_create_args"] = {}
+                    extra_create_args = valid_kwargs["extra_create_args"]
+                    if not isinstance(extra_create_args, dict):
+                        valid_kwargs["extra_create_args"] = dict(extra_create_args)
+                        extra_create_args = valid_kwargs["extra_create_args"]
+                    if "stream_options" not in extra_create_args:
+                        extra_create_args["stream_options"] = {}
+                    extra_create_args["stream_options"]["include_usage"] = True
+            elif "include_usage" in method_signature.parameters:
+                # Set include_usage parameter directly if supported
+                valid_kwargs["include_usage"] = True
+
+        model_name = getattr(instance, "model", None)
 
         span_name = f"{instance.__class__.__name__}.create_stream"
         with self._tracer.start_as_current_span(
@@ -545,6 +660,7 @@
                         OPENINFERENCE_SPAN_KIND: LLM,
                         **dict(_llm_messages_attributes(messages, "input")),
                         **dict(_get_llm_tool_attributes(tools)),
+                        **dict(get_llm_model_name_attributes(model_name)),
                         **dict(get_attributes_from_context()),
                     }
                 )
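For reference, the `include_usage` injection above amounts to the following normalization, restated as a standalone, runnable sketch. It deliberately omits the `method_signature.parameters` checks, and `ensure_include_usage` is a name invented for this illustration:

```python
from typing import Any, Dict

def ensure_include_usage(valid_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Mirrors the wrapper's injection: if the caller did not opt in, request
    # usage reporting via extra_create_args["stream_options"] so OpenAI-style
    # streaming responses emit a final chunk that carries token counts.
    if valid_kwargs.get("include_usage") is None:
        extra_create_args = dict(valid_kwargs.get("extra_create_args") or {})
        stream_options = dict(extra_create_args.get("stream_options") or {})
        if not stream_options.get("include_usage"):
            stream_options["include_usage"] = True
        extra_create_args["stream_options"] = stream_options
        valid_kwargs["extra_create_args"] = extra_create_args
    return valid_kwargs

print(ensure_include_usage({}))
# prints: {'extra_create_args': {'stream_options': {'include_usage': True}}}
```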
@@ -552,24 +668,52 @@
             record_exception=False,
             set_status_on_exception=False,
         ) as span:
+            # Track the latest output, tool call, and token usage attributes
+            latest_token_attributes = {}
+            latest_output_attributes = {}
+            latest_tool_call_attributes = {}
+
             try:
                 async for res in wrapped(*args, **valid_kwargs):
                     if isinstance(res, CreateResult):
                         # Extract output attributes and process tool calls in response
                         output_attributes = dict(get_output_attributes(res))
                         tool_call_attributes = dict(_extract_output_tool_calls(res))
-
-                        span.set_attributes(
-                            dict(
-                                _flatten(
-                                    {
-                                        **output_attributes,
-                                        **tool_call_attributes,
-                                    }
-                                )
+
+                        # Extract token usage information
+                        token_usage = _extract_token_usage(res)
+                        token_attributes = dict(get_llm_token_count_attributes(token_usage))
+
+                        # Add completion_details if available (reasoning_tokens, audio_tokens)
+                        if completion_details := token_usage.get("completion_details"):
+                            if isinstance(completion_details, dict):
+                                if reasoning_tokens := completion_details.get("reasoning"):
+                                    token_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING] = reasoning_tokens
+                                if audio_tokens := completion_details.get("audio"):
+                                    token_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO] = audio_tokens
+
+                        # Update latest attributes (the last ones win after the stream completes)
+                        if output_attributes:
+                            latest_output_attributes.update(output_attributes)
+                        if tool_call_attributes:
+                            latest_tool_call_attributes.update(tool_call_attributes)
+                        if token_attributes:
+                            latest_token_attributes.update(token_attributes)
+                    yield res
+
+                # Set all attributes after the stream completes
+                if latest_output_attributes or latest_tool_call_attributes or latest_token_attributes:
+                    span.set_attributes(
+                        dict(
+                            _flatten(
+                                {
+                                    **latest_output_attributes,
+                                    **latest_tool_call_attributes,
+                                    **latest_token_attributes,
+                                }
                             )
                         )
-                    yield res
+                    )
             except Exception as exception:
                 span.set_status(trace_api.Status(trace_api.StatusCode.ERROR, str(exception)))
                 span.record_exception(exception)
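The restructured handler is an instance of a general pattern for instrumenting async generators: re-yield every chunk, remember the latest observation, and record span attributes only once the stream is exhausted, since OpenAI-style streams report usage on the final `CreateResult`. A generic, runnable illustration of the pattern (all names here are invented for the sketch):

```python
import asyncio

async def chunks():
    # Stand-in for create_stream: string deltas, then a final summary value.
    for item in ("he", "llo", {"usage": {"prompt": 3, "completion": 2}}):
        yield item

async def instrumented(stream):
    latest = None
    async for item in stream:
        if isinstance(item, dict):  # analogous to isinstance(res, CreateResult)
            latest = item           # remember the most recent observation
        yield item                  # re-yield so the caller sees every chunk
    # Only after exhaustion is it safe to record: the last chunk has the totals.
    print("span attributes set once:", latest)

async def main():
    async for _ in instrumented(chunks()):
        pass

asyncio.run(main())
# prints: span attributes set once: {'usage': {'prompt': 3, 'completion': 2}}
```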