fix: expose multimodal agent metrics (#1080)
longcw authored Nov 17, 2024
1 parent 383f102 commit 6f330b2
Showing 5 changed files with 106 additions and 6 deletions.
6 changes: 6 additions & 0 deletions .changeset/tough-boats-appear.md
@@ -0,0 +1,6 @@
---
"livekit-plugins-openai": patch
"livekit-agents": patch
---

Expose multimodal agent metrics
4 changes: 4 additions & 0 deletions livekit-agents/livekit/agents/metrics/__init__.py
@@ -1,6 +1,8 @@
from .base import (
    AgentMetrics,
    LLMMetrics,
    MultimodalLLMError,
    MultimodalLLMMetrics,
    PipelineEOUMetrics,
    PipelineLLMMetrics,
    PipelineSTTMetrics,
@@ -16,6 +18,8 @@

__all__ = [
"LLMMetrics",
"MultimodalLLMError",
"MultimodalLLMMetrics",
"AgentMetrics",
"PipelineEOUMetrics",
"PipelineSTTMetrics",
26 changes: 26 additions & 0 deletions livekit-agents/livekit/agents/metrics/base.py
@@ -98,6 +98,31 @@ class PipelineVADMetrics(VADMetrics):
    pass


@dataclass
class MultimodalLLMError(Error):
    type: str | None
    reason: str | None = None
    code: str | None = None
    message: str | None = None


@dataclass
class MultimodalLLMMetrics(LLMMetrics):
    @dataclass
    class InputTokenDetails:
        cached_tokens: int
        text_tokens: int
        audio_tokens: int

    @dataclass
    class OutputTokenDetails:
        text_tokens: int
        audio_tokens: int

    input_token_details: InputTokenDetails
    output_token_details: OutputTokenDetails


AgentMetrics = Union[
    STTMetrics,
    LLMMetrics,
@@ -108,4 +133,5 @@ class PipelineVADMetrics(VADMetrics):
    PipelineLLMMetrics,
    PipelineTTSMetrics,
    PipelineVADMetrics,
    MultimodalLLMMetrics,
]
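
For orientation, the sketch below shows what a fully populated metrics object might look like. The `LLMMetrics` base fields (`timestamp`, `request_id`, `ttft`, `duration`, `cancelled`, `label`, the token counts, and `error`) are inferred from the constructor call in the OpenAI realtime plugin further down this diff, and every value here is illustrative only:

from livekit.agents.metrics import MultimodalLLMMetrics

metrics = MultimodalLLMMetrics(
    timestamp=1731800000.0,  # time.time() when the response was created
    request_id="resp_123",
    ttft=0.42,               # seconds to first token; -1.0 if no token arrived
    duration=2.4,            # seconds from creation to response_done
    cancelled=False,
    label="livekit.plugins.openai.realtime_model.RealtimeSession",  # "module.class"; exact value illustrative
    completion_tokens=120,
    prompt_tokens=800,
    total_tokens=920,
    tokens_per_second=50.0,  # completion_tokens / duration
    error=None,              # a MultimodalLLMError on failed/incomplete responses
    input_token_details=MultimodalLLMMetrics.InputTokenDetails(
        cached_tokens=0, text_tokens=600, audio_tokens=200
    ),
    output_token_details=MultimodalLLMMetrics.OutputTokenDetails(
        text_tokens=40, audio_tokens=80
    ),
)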
6 changes: 6 additions & 0 deletions livekit-agents/livekit/agents/multimodal/multimodal_agent.py
@@ -8,6 +8,7 @@
from livekit import rtc
from livekit.agents import llm, stt, tokenize, transcription, utils, vad
from livekit.agents.llm import ChatMessage
from livekit.agents.metrics import MultimodalLLMMetrics

from .._constants import ATTRIBUTE_AGENT_STATE
from .._types import AgentState
@@ -24,6 +25,7 @@
"agent_speech_interrupted",
"function_calls_collected",
"function_calls_finished",
"metrics_collected",
]


@@ -240,6 +242,10 @@ def _function_calls_collected(fnc_call_infos: list[llm.FunctionCallInfo]):
        def _function_calls_finished(called_fncs: list[llm.CalledFunction]):
            self.emit("function_calls_finished", called_fncs)

        @self._session.on("metrics_collected")
        def _metrics_collected(metrics: MultimodalLLMMetrics):
            self.emit("metrics_collected", metrics)

    def _update_state(self, state: AgentState, delay: float = 0.0):
        """Set the current state of the agent"""

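Since `MultimodalAgent` now re-emits the session's `metrics_collected` event (handler above), application code can observe metrics on the agent directly. A minimal consumer sketch; `attach_metrics_logging` is a hypothetical helper, not part of this diff:

from livekit.agents.metrics import MultimodalLLMMetrics
from livekit.agents.multimodal import MultimodalAgent


def attach_metrics_logging(agent: MultimodalAgent) -> None:
    @agent.on("metrics_collected")
    def on_metrics(metrics: MultimodalLLMMetrics) -> None:
        # Forward to your own logging or telemetry sink.
        if metrics.error is not None:
            print("response failed:", metrics.error.type, metrics.error.code, metrics.error.message)
        else:
            print(f"ttft={metrics.ttft:.2f}s throughput={metrics.tokens_per_second:.1f} tok/s")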
70 changes: 64 additions & 6 deletions livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/realtime/realtime_model.py
@@ -3,6 +3,7 @@
import asyncio
import base64
import os
import time
from copy import deepcopy
from dataclasses import dataclass
from typing import AsyncIterable, Literal, Union, cast, overload
@@ -12,6 +13,7 @@
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm import _oai_api
from livekit.agents.metrics import MultimodalLLMError, MultimodalLLMMetrics
from typing_extensions import TypedDict

from . import api_proto, remote_items
@@ -33,6 +35,7 @@
"response_done",
"function_calls_collected",
"function_calls_finished",
"metrics_collected",
]


@@ -66,6 +69,10 @@ class RealtimeResponse:
"""usage of the response"""
done_fut: asyncio.Future[None]
"""future that will be set when the response is completed"""
_created_timestamp: float
"""timestamp when the response was created"""
_first_token_timestamp: float | None = None
"""timestamp when the first token was received"""


@dataclass
@@ -695,6 +702,7 @@ def __init__(
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        super().__init__()
        self._label = f"{type(self).__module__}.{type(self).__name__}"
        self._main_atask = asyncio.create_task(
            self._main_task(), name="openai-realtime-session"
        )
@@ -1203,6 +1211,7 @@ def _handle_response_created(
            output=[],
            usage=response.get("usage"),
            done_fut=done_fut,
            _created_timestamp=time.time(),
        )
        self._pending_responses[new_response.id] = new_response
        self.emit("response_created", new_response)
@@ -1257,6 +1266,7 @@ def _handle_response_content_part_added(
            content_type=content_type,
        )
        output.content.append(new_content)
        response._first_token_timestamp = time.time()
        self.emit("response_content_added", new_content)

    def _handle_response_audio_delta(
@@ -1361,15 +1371,19 @@ def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDone):
        response.status_details = response_data.get("status_details")
        response.usage = response_data.get("usage")

        metrics_error = None
        cancelled = False
        if response.status == "failed":
            assert response.status_details is not None

-            error = response.status_details.get("error")
-            code: str | None = None
-            message: str | None = None
-            if error is not None:
-                code = error.get("code")  # type: ignore
-                message = error.get("message")  # type: ignore
+            error = response.status_details.get("error", {})
+            code: str | None = error.get("code")  # type: ignore
+            message: str | None = error.get("message")  # type: ignore
+            metrics_error = MultimodalLLMError(
+                type=response.status_details.get("type"),
+                code=code,
+                message=message,
+            )

            logger.error(
                "response generation failed",
@@ -1379,13 +1393,57 @@ def _handle_response_done(self, response_done: api_proto.ServerEvent.ResponseDone):
            assert response.status_details is not None
            reason = response.status_details.get("reason")

            metrics_error = MultimodalLLMError(
                type=response.status_details.get("type"),
                reason=reason,  # type: ignore
            )

            logger.warning(
                "response generation incomplete",
                extra={"reason": reason, **self.logging_extra()},
            )
        elif response.status == "cancelled":
            cancelled = True

        self.emit("response_done", response)

        # calculate metrics
        ttft = -1.0
        if response._first_token_timestamp is not None:
            ttft = response._first_token_timestamp - response._created_timestamp
        duration = time.time() - response._created_timestamp

        usage = response.usage or {}  # type: ignore
        metrics = MultimodalLLMMetrics(
            timestamp=response._created_timestamp,
            request_id=response.id,
            ttft=ttft,
            duration=duration,
            cancelled=cancelled,
            label=self._label,
            completion_tokens=usage.get("output_tokens", 0),
            prompt_tokens=usage.get("input_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
            tokens_per_second=usage.get("output_tokens", 0) / duration,
            error=metrics_error,
            input_token_details=MultimodalLLMMetrics.InputTokenDetails(
                cached_tokens=usage.get("input_token_details", {}).get(
                    "cached_tokens", 0
                ),
                text_tokens=usage.get("input_token_details", {}).get("text_tokens", 0),
                audio_tokens=usage.get("input_token_details", {}).get(
                    "audio_tokens", 0
                ),
            ),
            output_token_details=MultimodalLLMMetrics.OutputTokenDetails(
                text_tokens=usage.get("output_token_details", {}).get("text_tokens", 0),
                audio_tokens=usage.get("output_token_details", {}).get(
                    "audio_tokens", 0
                ),
            ),
        )
        self.emit("metrics_collected", metrics)

    def _get_content(self, ptr: _ContentPtr) -> RealtimeContent:
        response = self._pending_responses[ptr["response_id"]]
        output = response.output[ptr["output_index"]]
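To make the calculation above concrete, here is the same arithmetic worked through on hypothetical timestamps and token counts (illustrative numbers only):

created = 100.0      # response._created_timestamp, set in _handle_response_created
first_token = 100.4  # response._first_token_timestamp, set on first content part
done = 102.4         # time.time() when response.done is handled
output_tokens = 120  # usage["output_tokens"]

ttft = first_token - created                  # 0.4 s (would stay -1.0 if no token arrived)
duration = done - created                     # 2.4 s, creation to completion
tokens_per_second = output_tokens / duration  # 50.0 tok/s

Because `duration` spans the whole response lifecycle, `tokens_per_second` is an end-to-end throughput figure rather than a decode-only rate.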
