libs/partners/perplexity/langchain_perplexity/chat_models.py (42 changes: 39 additions & 3 deletions)
@@ -28,7 +28,11 @@
    SystemMessageChunk,
    ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata, subtract_usage
from langchain_core.messages.ai import (
    OutputTokenDetails,
    UsageMetadata,
    subtract_usage,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
@@ -49,13 +53,28 @@ def _is_pydantic_class(obj: Any) -> bool:


def _create_usage_metadata(token_usage: dict) -> UsageMetadata:
    """Create UsageMetadata from Perplexity token usage data.

    Args:
        token_usage: Dictionary containing token usage information from the
            Perplexity API.

    Returns:
        UsageMetadata with properly structured token counts and details.
    """
    input_tokens = token_usage.get("prompt_tokens", 0)
    output_tokens = token_usage.get("completion_tokens", 0)
    total_tokens = token_usage.get("total_tokens", input_tokens + output_tokens)

    # Build output_token_details for Perplexity-specific fields
    output_token_details: OutputTokenDetails = {}
    output_token_details["reasoning"] = token_usage.get("reasoning_tokens", 0)
    output_token_details["citation_tokens"] = token_usage.get("citation_tokens", 0)  # type: ignore[typeddict-unknown-key]
Comment on lines +69 to +71 (Collaborator):

(nit) might make sense to only populate the keys if they are present in token_usage.
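A minimal sketch of that suggestion, assuming the same field names as the diff above (not part of the PR):

    # Reviewer's nit, sketched: only set detail keys that Perplexity actually
    # returned, so absent fields don't show up as zeros.
    output_token_details: OutputTokenDetails = {}
    if (reasoning := token_usage.get("reasoning_tokens")) is not None:
        output_token_details["reasoning"] = reasoning
    if (citations := token_usage.get("citation_tokens")) is not None:
        output_token_details["citation_tokens"] = citations  # type: ignore[typeddict-unknown-key]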


    return UsageMetadata(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        total_tokens=total_tokens,
        output_token_details=output_token_details,
    )
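For illustration, a hypothetical Perplexity-style usage payload run through the helper (all values invented):

usage = {
    "prompt_tokens": 12,
    "completion_tokens": 40,
    "total_tokens": 52,
    "reasoning_tokens": 8,
    "citation_tokens": 3,
}
metadata = _create_usage_metadata(usage)
# metadata["input_tokens"] == 12, metadata["output_tokens"] == 40
# metadata["output_token_details"] == {"reasoning": 8, "citation_tokens": 3}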


@@ -301,6 +320,7 @@ def _stream(
        prev_total_usage: UsageMetadata | None = None

        added_model_name: bool = False
        added_search_queries: bool = False
        for chunk in stream_resp:
            if not isinstance(chunk, dict):
                chunk = chunk.model_dump()
@@ -332,6 +352,13 @@
generation_info["model_name"] = model_name
added_model_name = True

# Add num_search_queries to generation_info if present
if total_usage := chunk.get("usage"):
if num_search_queries := total_usage.get("num_search_queries"):
if not added_search_queries:
generation_info["num_search_queries"] = num_search_queries
added_search_queries = True

chunk = self._convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
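A sketch of how the query count surfaces to a streaming caller, assuming generation_info is merged into each chunk's response_metadata as langchain-core does elsewhere (model name and prompt are placeholders):

from langchain_perplexity import ChatPerplexity

chat = ChatPerplexity(model="sonar")  # assumes PPLX_API_KEY is set
for chunk in chat.stream("What is the tallest mountain in the world?"):
    # num_search_queries is attached once, on the first chunk whose usage reports it
    if "num_search_queries" in chunk.response_metadata:
        print(chunk.response_metadata["num_search_queries"])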
@@ -369,20 +396,29 @@ def _generate(
        params = {**params, **kwargs}
        response = self.client.chat.completions.create(messages=message_dicts, **params)
        if usage := getattr(response, "usage", None):
            usage_dict = usage.model_dump()
            usage_metadata = _create_usage_metadata(usage_dict)
        else:
            usage_metadata = None
            usage_dict = {}

        additional_kwargs = {}
        for attr in ["citations", "images", "related_questions", "search_results"]:
            if hasattr(response, attr):
                additional_kwargs[attr] = getattr(response, attr)

        # Build response_metadata with model_name and num_search_queries
        response_metadata: dict[str, Any] = {
            "model_name": getattr(response, "model", self.model)
        }
        if num_search_queries := usage_dict.get("num_search_queries"):
            response_metadata["num_search_queries"] = num_search_queries

        message = AIMessage(
            content=response.choices[0].message.content,
            additional_kwargs=additional_kwargs,
            usage_metadata=usage_metadata,
            response_metadata={"model_name": getattr(response, "model", self.model)},
            response_metadata=response_metadata,
        )
        return ChatResult(generations=[ChatGeneration(message=message)])
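The non-streaming path, sketched the same way (placeholders again; num_search_queries only appears when the API reports it):

from langchain_perplexity import ChatPerplexity

chat = ChatPerplexity(model="sonar")  # assumes PPLX_API_KEY is set
msg = chat.invoke("How many moons does Jupiter have?")
print(msg.response_metadata["model_name"])
print(msg.response_metadata.get("num_search_queries"))  # None if not reported
print(msg.usage_metadata)  # includes output_token_details built by the helper above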
