@@ -36,6 +36,8 @@
from openinference.instrumentation import (
    get_attributes_from_context,
    get_output_attributes,
    get_llm_model_name_attributes,
    get_llm_token_count_attributes,
    safe_json_dumps,
)
from openinference.semconv.trace import (
@@ -48,6 +50,7 @@
    ToolAttributes,
    ToolCallAttributes,
)
from openinference.semconv.trace import SpanAttributes

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
@@ -97,6 +100,80 @@ def _flatten(mapping: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Attrib
        yield key, value


def _get_token_value(obj: Any, attr_name: str) -> Optional[int]:
    """Extract a token count from an object attribute, if present."""
    if hasattr(obj, attr_name):
        value = getattr(obj, attr_name)
        if value is not None:
            return value
    return None


def _extract_details_from_object(details: Any, mapping: Dict[str, str]) -> Dict[str, int]:
    """Extract token details from either a dict or an attribute-bearing object."""
    result: Dict[str, int] = {}
    if isinstance(details, dict):
        for key, attr_name in mapping.items():
            value = details.get(attr_name)
            if value is not None:
                result[key] = value
    else:
        for key, attr_name in mapping.items():
            value = _get_token_value(details, attr_name)
            if value is not None:
                result[key] = value
    return result
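
# A quick sketch of how this helper behaves; the literal values below are
# hypothetical, and either a plain dict or an attribute-bearing object works:
#
#     _extract_details_from_object({"cached_tokens": 8}, {"cache_read": "cached_tokens"})
#     # -> {"cache_read": 8}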


def _extract_token_usage(result: Any) -> Dict[str, Any]:
    """Extract token usage information from a CreateResult."""
    if not hasattr(result, 'usage') or not result.usage:
        return {}

    usage = result.usage
    token_usage: Dict[str, Any] = {}

    prompt_tokens = _get_token_value(usage, 'prompt_tokens')
    if prompt_tokens is not None:
        token_usage["prompt"] = prompt_tokens

    completion_tokens = _get_token_value(usage, 'completion_tokens')
    if completion_tokens is not None:
        token_usage["completion"] = completion_tokens

    details = None
    if hasattr(usage, 'prompt_tokens_details'):
        details = usage.prompt_tokens_details
    elif hasattr(usage, '__dict__') and 'prompt_tokens_details' in usage.__dict__:
        details = usage.__dict__['prompt_tokens_details']

    if details:
        prompt_details = _extract_details_from_object(
            details,
            {"cache_read": "cached_tokens", "audio": "audio_tokens", "cache_input": "text_tokens"},
            # Review note (from PR discussion): Bug: token attribution fails for
            # cached input. The mapping includes "cache_input": "text_tokens", but
            # get_llm_token_count_attributes doesn't handle the "cache_input" key in
            # prompt_details, so a text_tokens value extracted here is never
            # converted to a span attribute, making this extraction ineffective.
        )
        if prompt_details:
            token_usage["prompt_details"] = prompt_details

    details = None
    if hasattr(usage, 'completion_tokens_details'):
        details = usage.completion_tokens_details
    elif hasattr(usage, '__dict__') and 'completion_tokens_details' in usage.__dict__:
        details = usage.__dict__['completion_tokens_details']

    if details:
        completion_details = _extract_details_from_object(
            details,
            {"reasoning": "reasoning_tokens", "audio": "audio_tokens"},
        )
        if completion_details:
            token_usage["completion_details"] = completion_details

    if "prompt" in token_usage and "completion" in token_usage:
        token_usage["total"] = token_usage["prompt"] + token_usage["completion"]

    return token_usage
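
# For illustration, a sketch of the dict this helper returns, assuming an
# OpenAI-style usage object (the SimpleNamespace stand-ins below are
# hypothetical, not part of this module):
#
#     from types import SimpleNamespace
#     usage = SimpleNamespace(
#         prompt_tokens=12,
#         completion_tokens=34,
#         prompt_tokens_details=SimpleNamespace(cached_tokens=8, audio_tokens=None, text_tokens=4),
#         completion_tokens_details=SimpleNamespace(reasoning_tokens=5, audio_tokens=None),
#     )
#     _extract_token_usage(SimpleNamespace(usage=usage))
#     # -> {"prompt": 12, "completion": 34,
#     #     "prompt_details": {"cache_read": 8, "cache_input": 4},
#     #     "completion_details": {"reasoning": 5},
#     #     "total": 46}
#
# Note that, per the review note above, the "cache_input" entry is currently
# dropped downstream by get_llm_token_count_attributes.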

def _get_input_value(method: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
"""
Parses a method call's inputs into a JSON string. Ensures a consistent
@@ -466,6 +543,9 @@ async def __call__(
            if param_name in method_signature.parameters:
                valid_kwargs[param_name] = param_value

        # Extract model name from instance
        model_name = getattr(instance, 'model', None)

        with self._tracer.start_as_current_span(
            span_name,
            attributes=dict(
@@ -474,6 +554,7 @@
                        OPENINFERENCE_SPAN_KIND: LLM,
                        **dict(_llm_messages_attributes(messages, "input")),
                        **dict(_get_llm_tool_attributes(tools)),
                        **dict(get_llm_model_name_attributes(model_name)),
                        **dict(get_attributes_from_context()),
                    }
                )
@@ -488,13 +569,25 @@
            # Extract output attributes and process tool calls in response
            output_attributes = dict(get_output_attributes(result))
            tool_call_attributes = dict(_extract_output_tool_calls(result))

            # Extract token usage information
            token_usage = _extract_token_usage(result)
            token_attributes = dict(get_llm_token_count_attributes(token_usage))

            # Add completion_details if available (reasoning_tokens, audio_tokens)
            if completion_details := token_usage.get("completion_details"):
                if isinstance(completion_details, dict):
                    if reasoning_tokens := completion_details.get("reasoning"):
                        token_attributes[
                            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
                        ] = reasoning_tokens
                    if audio_tokens := completion_details.get("audio"):
                        token_attributes[
                            SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO
                        ] = audio_tokens
            span.set_attributes(
                dict(
                    _flatten(
                        {
                            **output_attributes,
                            **tool_call_attributes,
                            **token_attributes,
                        }
                    )
                )
@@ -535,6 +628,28 @@
        for param_name, param_value in kwargs.items():
            if param_name in method_signature.parameters:
                valid_kwargs[param_name] = param_value

        # Ensure stream_options includes include_usage for token metrics
        if valid_kwargs.get("include_usage") is None:
            extra_create_args = valid_kwargs.get("extra_create_args", {})
            if isinstance(extra_create_args, dict):
                stream_options = extra_create_args.get("stream_options", {})
                if not stream_options.get("include_usage"):
                    # Inject stream_options to ensure we get usage data
                    if "extra_create_args" not in valid_kwargs:
                        valid_kwargs["extra_create_args"] = {}
                    extra_create_args = valid_kwargs["extra_create_args"]
                    if not isinstance(extra_create_args, dict):
                        valid_kwargs["extra_create_args"] = dict(extra_create_args)
                        extra_create_args = valid_kwargs["extra_create_args"]
                    if "stream_options" not in extra_create_args:
                        extra_create_args["stream_options"] = {}
                    extra_create_args["stream_options"]["include_usage"] = True
            elif "include_usage" in method_signature.parameters:
                # Set include_usage parameter directly if supported
                valid_kwargs["include_usage"] = True

        model_name = getattr(instance, 'model', None)

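        # For illustration (hypothetical call; create_stream is the wrapped
        # client method): a call that arrives as
        #     create_stream(messages=msgs)
        # is forwarded, after the injection above, as
        #     create_stream(messages=msgs,
        #                   extra_create_args={"stream_options": {"include_usage": True}})
        # so that, on OpenAI-compatible backends, the final stream chunk carries
        # a usage payload for _extract_token_usage to pick up.
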
        span_name = f"{instance.__class__.__name__}.create_stream"
        with self._tracer.start_as_current_span(
@@ -545,31 +660,60 @@
                        OPENINFERENCE_SPAN_KIND: LLM,
                        **dict(_llm_messages_attributes(messages, "input")),
                        **dict(_get_llm_tool_attributes(tools)),
                        **dict(get_llm_model_name_attributes(model_name)),
                        **dict(get_attributes_from_context()),
                    }
                )
            ),
            record_exception=False,
            set_status_on_exception=False,
        ) as span:
            # Track the latest token usage and output attributes
            latest_token_attributes: Dict[str, Any] = {}
            latest_output_attributes: Dict[str, Any] = {}
            latest_tool_call_attributes: Dict[str, Any] = {}

            try:
                async for res in wrapped(*args, **valid_kwargs):
                    if isinstance(res, CreateResult):
                        # Extract output attributes and process tool calls in response
                        output_attributes = dict(get_output_attributes(res))
                        tool_call_attributes = dict(_extract_output_tool_calls(res))

                        # Extract token usage information
                        token_usage = _extract_token_usage(res)
                        token_attributes = dict(get_llm_token_count_attributes(token_usage))

                        # Add completion_details if available (reasoning_tokens, audio_tokens)
                        if completion_details := token_usage.get("completion_details"):
                            if isinstance(completion_details, dict):
                                if reasoning_tokens := completion_details.get("reasoning"):
                                    token_attributes[
                                        SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING
                                    ] = reasoning_tokens
                                if audio_tokens := completion_details.get("audio"):
                                    token_attributes[
                                        SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_AUDIO
                                    ] = audio_tokens

                        # Update latest attributes (the last ones win after the stream completes)
                        if output_attributes:
                            latest_output_attributes.update(output_attributes)
                        if tool_call_attributes:
                            latest_tool_call_attributes.update(tool_call_attributes)
                        if token_attributes:
                            latest_token_attributes.update(token_attributes)
                    yield res

                # Set all attributes after the stream completes
                if latest_output_attributes or latest_tool_call_attributes or latest_token_attributes:
                    span.set_attributes(
                        dict(
                            _flatten(
                                {
                                    **latest_output_attributes,
                                    **latest_tool_call_attributes,
                                    **latest_token_attributes,
                                }
                            )
                        )
                    )
            except Exception as exception:
                span.set_status(trace_api.Status(trace_api.StatusCode.ERROR, str(exception)))
                span.record_exception(exception)