Skip to content

Commit

Permalink
chore: fix type check
Browse files Browse the repository at this point in the history
  • Loading branch information
longcw committed Nov 11, 2024
1 parent 41e8152 commit 28e8e7b
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 28 deletions.
4 changes: 0 additions & 4 deletions examples/multimodal_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,6 @@ async def get_weather(

# create a chat context with chat history
chat_ctx = llm.ChatContext()

# Add some test context to verify if the sync_chat_ctx works
# FIXME: OAI realtime API does not support this properly when the chat context is too many
# It may answer with the text responses only for some cases
chat_ctx.append(text="I'm planning a trip to Paris next month.", role="user")
chat_ctx.append(
text="How exciting! Paris is a beautiful city. I'd be happy to suggest some must-visit places and help you plan your trip.",
Expand Down
9 changes: 4 additions & 5 deletions livekit-agents/livekit/agents/llm/chat_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,16 @@ class ChatAudio:
@dataclass
class ChatMessage:
role: ChatRole
id: str | None = None # used by the OAI realtime API
id: str = field(
default_factory=lambda: utils.shortuuid("item_")
) # used by the OAI realtime API
name: str | None = None
content: ChatContent | list[ChatContent] | None = None
tool_calls: list[function_context.FunctionCallInfo] | None = None
tool_call_id: str | None = None
tool_exception: Exception | None = None
_metadata: dict[str, Any] = field(default_factory=dict, repr=False, init=False)

def __post_init__(self):
if self.id is None:
self.id = utils.shortuuid("item_")

@staticmethod
def create_tool_from_called_function(
called_function: function_context.CalledFunction,
Expand Down Expand Up @@ -97,6 +95,7 @@ def create(
role: ChatRole = "system",
id: str | None = None,
) -> "ChatMessage":
id = id or utils.shortuuid("item_")
if len(images) == 0:
return ChatMessage(role=role, content=text, id=id)
else:
Expand Down
6 changes: 3 additions & 3 deletions livekit-agents/livekit/agents/utils/_message_change.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ def _compute_list_changes(old_list: list[T], new_list: list[T]) -> MessageChange
first_idx = old_list.index(new_list[0])
except ValueError:
# Special case: if first item is new, delete everything
prev_item = None
to_add = []
prev_item: T | None = None
to_add: list[tuple[T | None, T]] = []
for x in new_list:
to_add.append((prev_item, x))
prev_item = x
Expand Down Expand Up @@ -106,7 +106,7 @@ def _compute_list_changes(old_list: list[T], new_list: list[T]) -> MessageChange
to_delete.extend(x for x in remaining_old if x not in kept_items)

# Compute items to add by following new list order
to_add: list[tuple[T | None, T]] = []
to_add = []
prev_item = None
for x in new_list:
if x not in kept_items:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,30 +214,42 @@ class InputAudioBufferClear(TypedDict):
type: Literal["input_audio_buffer.clear"]

class UserItemCreate(TypedDict):
id: str | None
type: Literal["message"]
role: Literal["user"]
content: list[InputTextContent | InputAudioContent]

class AssistantItemCreate(TypedDict):
id: str | None
type: Literal["message"]
role: Literal["assistant"]
content: list[TextContent]

class SystemItemCreate(TypedDict):
id: str | None
type: Literal["message"]
role: Literal["system"]
content: list[InputTextContent]

class FunctionCallOutputItemCreate(TypedDict):
id: str | None
type: Literal["function_call_output"]
call_id: str
output: str

class FunctionCallItemCreate(TypedDict):
id: str | None
type: Literal["function_call"]
call_id: str
name: str
arguments: str

ConversationItemCreateContent = Union[
UserItemCreate,
AssistantItemCreate,
SystemItemCreate,
FunctionCallOutputItemCreate,
FunctionCallItemCreate,
]

class ConversationItemCreate(TypedDict):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from copy import deepcopy
from dataclasses import dataclass
from typing import AsyncIterable, Callable, Literal, overload
from typing import AsyncIterable, Callable, Literal, cast, overload
from urllib.parse import urlencode

import aiohttp
Expand All @@ -31,6 +31,8 @@
"response_content_done",
"response_output_done",
"response_done",
"function_calls_collected",
"function_calls_finished",
]


Expand Down Expand Up @@ -505,9 +507,10 @@ def create(
}
else:
# function_call
if not message.tool_calls:
if not message.tool_calls or message.name is None:
logger.warning(
"function call message has no tool calls",
"function call message has no name or tool calls: %s",
message,
extra=self._sess.logging_extra(),
)
return
Expand All @@ -530,6 +533,13 @@ def create(
},
}
else:
if message_content is None:
logger.warning(
"message content is None, skipping: %s",
message,
extra=self._sess.logging_extra(),
)
return
if not isinstance(message_content, list):
message_content = [message_content]

Expand Down Expand Up @@ -595,18 +605,23 @@ def create(
system_contents: list[api_proto.InputTextContent] = []
for cnt in message_content:
if isinstance(cnt, str):
system_contents.append(
{
"id": message.id,
"type": "input_text",
"text": cnt,
}
)
system_contents.append({"type": "input_text", "text": cnt})
elif isinstance(cnt, llm.ChatAudio):
logger.warning(
"audio content in system message is not supported"
)

event = {
"type": "conversation.item.create",
"previous_item_id": previous_item_id,
"item": {
"id": message.id,
"type": "message",
"role": "system",
"content": system_contents,
},
}

if event is None:
logger.warning(
"chat message is not supported inside the realtime API %s",
Expand Down Expand Up @@ -654,7 +669,7 @@ async def acreate(
async def adelete(self, *, item_id: str) -> None:
fut = asyncio.Future[None]()
self._sess._item_deleted_futs[item_id] = fut
self.delete(item_id)
self.delete(item_id=item_id)
await fut
del self._sess._item_deleted_futs[item_id]

Expand Down Expand Up @@ -851,9 +866,9 @@ async def set_chat_ctx(self, new_ctx: llm.ChatContext) -> None:
},
)

# append an empty audio message if all messages are text
if new_ctx.messages and not any(
isinstance(msg.content, llm.ChatAudio) for msg in new_ctx.messages
# append an empty audio message if all new messages are text
if changes.to_add and not any(
isinstance(msg.content, llm.ChatAudio) for _, msg in changes.to_add
):
# Patch: add an empty audio message to the chat context
# to set the API in audio mode
Expand Down Expand Up @@ -901,8 +916,8 @@ def _update_converstation_item_content(
) -> None:
item = self._remote_converstation_items.get(item_id)
if item is None:
logger.error(
"conversation item not found",
logger.warning(
"conversation item not found, skipping update",
extra={"item_id": item_id},
)
return
Expand Down Expand Up @@ -1118,11 +1133,13 @@ def _handle_conversation_item_created(
# Leave the content empty and fill it in later from the content parts
if item_type == "message":
# Handle message items (system/user/assistant)
item = cast(api_proto.SystemItem | api_proto.UserItem, item)
role = item["role"]
message = llm.ChatMessage(id=item_id, role=role)
if item.get("content"):
content = item["content"][0]
if content["type"] in ("text", "input_text"):
content = cast(api_proto.InputTextContent, content)
message.content = content["text"]
elif content["type"] == "input_audio" and content.get("audio"):
audio_data = base64.b64decode(content["audio"])
Expand All @@ -1137,6 +1154,7 @@ def _handle_conversation_item_created(

elif item_type == "function_call":
# Handle function call items
item = cast(api_proto.FunctionCallItem, item)
message = llm.ChatMessage(
id=item_id,
role="assistant",
Expand All @@ -1146,6 +1164,7 @@ def _handle_conversation_item_created(

elif item_type == "function_call_output":
# Handle function call output items
item = cast(api_proto.FunctionCallOutputItem, item)
message = llm.ChatMessage(
id=item_id,
role="tool",
Expand Down

0 comments on commit 28e8e7b

Please sign in to comment.