change default model for google stt and add aws llm test case (#1552)
jayeshp19 authored Feb 28, 2025
1 parent 8946a7f commit 4935460
Showing 11 changed files with 154 additions and 68 deletions.
5 changes: 5 additions & 0 deletions .changeset/brave-brooms-rest.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-google": patch
---

google stt: change default model to `latest_long`
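For context, a minimal usage sketch (assuming the google plugin is installed and credentials are configured as described in the STT docstring below); the variable names are illustrative only:

from livekit.plugins import google

# Omitting `model` now selects "latest_long" (previously "chirp_2").
stt = google.STT()

# The earlier default can still be requested explicitly.
stt_chirp = google.STT(model="chirp_2")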
6 changes: 6 additions & 0 deletions .changeset/soft-tips-care.md
@@ -0,0 +1,6 @@
---
"livekit-plugins-anthropic": patch
"livekit-plugins-aws": patch
---

don't pass functions in params when tool choice is set to none
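A rough sketch of the call path this fixes, using the Anthropic plugin as an example (chat_ctx and fnc_ctx are assumed to be a pre-built ChatContext and FunctionContext, as in tests/test_llm.py below). With tool_choice set to "none", the plugins now omit the function definitions from the request instead of sending tools alongside a "none" choice:

from livekit.plugins import anthropic

stream = anthropic.LLM().chat(
    chat_ctx=chat_ctx,
    fnc_ctx=fnc_ctx,
    tool_choice="none",  # function definitions are no longer forwarded in this case
)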
3 changes: 2 additions & 1 deletion livekit-plugins/install_plugins_editable.sh
@@ -5,7 +5,8 @@ if [[ -z "$VIRTUAL_ENV" ]]; then
echo "You are not in a virtual environment."
exit 1
fi

pip install -e ./livekit-plugins-anthropic --config-settings editable_mode=strict
pip install -e ./livekit-plugins-aws --config-settings editable_mode=strict
pip install -e ./livekit-plugins-assemblyai --config-settings editable_mode=strict
pip install -e ./livekit-plugins-azure --config-settings editable_mode=strict
pip install -e ./livekit-plugins-cartesia --config-settings editable_mode=strict
Original file line number Diff line number Diff line change
@@ -171,7 +171,7 @@ def chat(

opts["tools"] = fncs_desc
if tool_choice is not None:
anthropic_tool_choice: dict[str, Any] = {"type": "auto"}
anthropic_tool_choice: dict[str, Any] | None = {"type": "auto"}
if isinstance(tool_choice, ToolChoice):
if tool_choice.type == "function":
anthropic_tool_choice = {
@@ -181,13 +181,20 @@
elif isinstance(tool_choice, str):
if tool_choice == "required":
anthropic_tool_choice = {"type": "any"}
if parallel_tool_calls is not None and parallel_tool_calls is False:
anthropic_tool_choice["disable_parallel_tool_use"] = True
opts["tool_choice"] = anthropic_tool_choice

latest_system_message: anthropic.types.TextBlockParam = _latest_system_message(
chat_ctx, caching=self._opts.caching
elif tool_choice == "none":
opts["tools"] = []
anthropic_tool_choice = None
if anthropic_tool_choice is not None:
if parallel_tool_calls is False:
anthropic_tool_choice["disable_parallel_tool_use"] = True
opts["tool_choice"] = anthropic_tool_choice

latest_system_message: anthropic.types.TextBlockParam | None = (
_latest_system_message(chat_ctx, caching=self._opts.caching)
)
if latest_system_message:
opts["system"] = [latest_system_message]

anthropic_ctx = _build_anthropic_context(
chat_ctx.messages,
id(self),
@@ -197,7 +204,6 @@ def chat(

stream = self._client.messages.create(
max_tokens=opts.get("max_tokens", 1024),
system=[latest_system_message],
messages=collaped_anthropic_ctx,
model=self._opts.model,
temperature=temperature or anthropic.NOT_GIVEN,
@@ -366,7 +372,7 @@ def _parse_event(

def _latest_system_message(
chat_ctx: llm.ChatContext, caching: Literal["ephemeral"] | None = None
) -> anthropic.types.TextBlockParam:
) -> anthropic.types.TextBlockParam | None:
latest_system_message: llm.ChatMessage | None = None
for m in chat_ctx.messages:
if m.role == "system":
@@ -381,12 +387,14 @@ def _latest_system_message(
latest_system_str = " ".join(
[c for c in latest_system_message.content if isinstance(c, str)]
)
system_text_block = anthropic.types.TextBlockParam(
text=latest_system_str,
type="text",
cache_control=CACHE_CONTROL_EPHEMERAL if caching == "ephemeral" else None,
)
return system_text_block
if latest_system_str:
system_text_block = anthropic.types.TextBlockParam(
text=latest_system_str,
type="text",
cache_control=CACHE_CONTROL_EPHEMERAL if caching == "ephemeral" else None,
)
return system_text_block
return None


def _merge_messages(
Original file line number Diff line number Diff line change
@@ -134,14 +134,13 @@ def _build_image(image: llm.ChatImage, cache_key: Any) -> dict:
height=image.inference_height,
strategy="scale_aspect_fit",
)
encoded_data = utils.images.encode(image.image, opts)
image._cache[cache_key] = base64.b64encode(encoded_data).decode("utf-8")
image._cache[cache_key] = utils.images.encode(image.image, opts)

return {
"image": {
"format": "jpeg",
"source": {
"bytes": image._cache[cache_key].encode("utf-8"),
"bytes": image._cache[cache_key],
},
}
}
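The net effect of this hunk, sketched with hypothetical variable names: the JPEG bytes returned by utils.images.encode are cached and passed straight through as the Bedrock image source, with no base64/UTF-8 round trip:

jpeg_bytes = utils.images.encode(image.image, opts)  # raw JPEG bytes
part = {
    "image": {
        "format": "jpeg",
        "source": {"bytes": jpeg_bytes},  # raw bytes instead of a base64 string
    }
}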
64 changes: 45 additions & 19 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/llm.py
@@ -184,33 +184,36 @@ async def _run(self) -> None:
try:
opts: dict[str, Any] = {}
messages, system_instruction = _build_aws_ctx(self._chat_ctx, id(self))
if messages[0]["role"] != "user":
messages.insert(
0,
{"role": "user", "content": [{"text": "(empty)"}]},
)
messages = _merge_messages(messages)

def _get_tool_config() -> dict[str, Any] | None:
if not (self._fnc_ctx and self._fnc_ctx.ai_functions):
return None

if self._fnc_ctx and self._fnc_ctx.ai_functions:
tools = _build_tools(self._fnc_ctx)
tool_config: dict[str, Any] = {"tools": tools}
config: dict[str, Any] = {"tools": tools}

if isinstance(self._tool_choice, ToolChoice):
tool_config["toolChoice"] = {
"tool": {"name": self._tool_choice.name}
}
config["toolChoice"] = {"tool": {"name": self._tool_choice.name}}
elif self._tool_choice == "required":
tool_config["toolChoice"] = {"any": {}}
config["toolChoice"] = {"any": {}}
elif self._tool_choice == "auto":
tool_config["toolChoice"] = {"auto": {}}
config["toolChoice"] = {"auto": {}}
else:
raise ValueError("aws bedrock llm: invalid tool choice")
return None

return config

tool_config = _get_tool_config()
if tool_config:
opts["toolConfig"] = tool_config

if self._additional_request_fields:
opts["additionalModelRequestFields"] = _strip_nones(
self._additional_request_fields
)
if system_instruction:
opts["system"] = [system_instruction]

inference_config = _strip_nones(
{
@@ -222,9 +225,8 @@ async def _run(self) -> None:
response = self._client.converse_stream(
modelId=self._model,
messages=messages,
system=[system_instruction],
inferenceConfig=inference_config,
**opts,
**_strip_nones(opts),
) # type: ignore

request_id = response["ResponseMetadata"]["RequestId"]
@@ -281,16 +283,16 @@ def _parse_chunk(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
return None

def _try_build_function(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
if not self._tool_call_id:
if self._tool_call_id is None:
logger.warning("aws bedrock llm: no tool call id in the response")
return None
if not self._fnc_name:
if self._fnc_name is None:
logger.warning("aws bedrock llm: no function name in the response")
return None
if not self._fnc_raw_arguments:
if self._fnc_raw_arguments is None:
logger.warning("aws bedrock llm: no function arguments in the response")
return None
if not self._fnc_ctx:
if self._fnc_ctx is None:
logger.warning(
"aws bedrock llm: stream tried to run function without function context"
)
@@ -320,5 +322,29 @@ def _try_build_function(self, request_id: str, chunk: dict) -> llm.ChatChunk | None:
)


def _merge_messages(
messages: list[dict],
) -> list[dict]:
# Anthropic enforces alternating messages
combined_messages: list[dict] = []
for m in messages:
if len(combined_messages) == 0 or m["role"] != combined_messages[-1]["role"]:
combined_messages.append(m)
continue
last_message = combined_messages[-1]
if not isinstance(last_message["content"], list) or not isinstance(
m["content"], list
):
logger.error("message content is not a list")
continue

last_message["content"].extend(m["content"])

if len(combined_messages) == 0 or combined_messages[0]["role"] != "user":
combined_messages.insert(0, {"role": "user", "content": [{"text": "(empty)"}]})

return combined_messages


def _strip_nones(d: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in d.items() if v is not None}
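A small sketch of what the new _merge_messages helper produces for the converse-style message dicts used above (input values are hypothetical):

messages = [
    {"role": "assistant", "content": [{"text": "Hello!"}]},
    {"role": "assistant", "content": [{"text": "How can I help?"}]},
    {"role": "user", "content": [{"text": "Cancel my order."}]},
]
merged = _merge_messages(messages)
# Consecutive same-role turns are combined and a synthetic leading user turn is
# inserted, yielding:
# [
#     {"role": "user", "content": [{"text": "(empty)"}]},
#     {"role": "assistant", "content": [{"text": "Hello!"}, {"text": "How can I help?"}]},
#     {"role": "user", "content": [{"text": "Cancel my order."}]},
# ]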
Original file line number Diff line number Diff line change
@@ -191,8 +191,7 @@ def _build_gemini_image_part(image: llm.ChatImage, cache_key: Any) -> types.Part
height=image.inference_height,
strategy="scale_aspect_fit",
)
encoded_data = utils.images.encode(image.image, opts)
image._cache[cache_key] = base64.b64encode(encoded_data).decode("utf-8")
image._cache[cache_key] = utils.images.encode(image.image, opts)

return types.Part.from_bytes(
data=image._cache[cache_key], mime_type="image/jpeg"
Original file line number Diff line number Diff line change
@@ -10,6 +10,8 @@
"medical_conversation",
"chirp",
"chirp_2",
"latest_long",
"latest_short",
]

SpeechLanguages = Literal[
Original file line number Diff line number Diff line change
@@ -61,7 +61,7 @@ class STTOptions:
interim_results: bool
punctuate: bool
spoken_punctuation: bool
model: SpeechModels
model: SpeechModels | str
sample_rate: int
keywords: List[tuple[str, float]] | None

Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
interim_results: bool = True,
punctuate: bool = True,
spoken_punctuation: bool = False,
model: SpeechModels = "chirp_2",
model: SpeechModels | str = "latest_long",
location: str = "us-central1",
sample_rate: int = 16000,
credentials_info: dict | None = None,
@@ -106,6 +106,19 @@
Credentials must be provided, either by using the ``credentials_info`` dict, or reading
from the file specified in ``credentials_file`` or via Application Default Credentials as
described in https://cloud.google.com/docs/authentication/application-default-credentials
args:
languages(LanguageCode): list of language codes to recognize (default: "en-US")
detect_language(bool): whether to detect the language of the audio (default: True)
interim_results(bool): whether to return interim results (default: True)
punctuate(bool): whether to punctuate the audio (default: True)
spoken_punctuation(bool): whether to use spoken punctuation (default: False)
model(SpeechModels): the model to use for recognition (default: "latest_long")
location(str): the location to use for recognition (default: "us-central1")
sample_rate(int): the sample rate of the audio (default: 16000)
credentials_info(dict): the credentials info to use for recognition (default: None)
credentials_file(str): the credentials file to use for recognition (default: None)
keywords(List[tuple[str, float]]): list of keywords to recognize (default: None)
"""
super().__init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
38 changes: 33 additions & 5 deletions tests/test_llm.py
@@ -9,7 +9,7 @@
import pytest
from livekit.agents import APIConnectionError, llm
from livekit.agents.llm import ChatContext, FunctionContext, TypeInfo, ai_callable
from livekit.plugins import anthropic, google, openai
from livekit.plugins import anthropic, aws, google, openai
from livekit.rtc import VideoBufferType, VideoFrame


@@ -101,7 +101,7 @@ def test_hashable_typeinfo():
pytest.param(lambda: anthropic.LLM(), id="anthropic"),
pytest.param(lambda: google.LLM(), id="google"),
pytest.param(lambda: google.LLM(vertexai=True), id="google-vertexai"),
# .param(lambda: aws.LLM(), id="aws"),
pytest.param(lambda: aws.LLM(), id="aws"),
]


@@ -131,6 +131,36 @@ async def test_chat(llm_factory: Callable[[], llm.LLM]):
assert len(text) > 0


@pytest.mark.parametrize("llm_factory", LLMS)
async def test_llm_chat_with_consecutive_messages(
llm_factory: Callable[[], llm.LLM],
):
input_llm = llm_factory()

chat_ctx = ChatContext()
chat_ctx.append(
text="Hello, How can I help you today?",
role="assistant",
)
chat_ctx.append(text="I see that you have a busy day ahead.", role="assistant")
chat_ctx.append(
text="Actually, I need some help with my recent order.", role="user"
)
chat_ctx.append(text="I want to cancel my order.", role="user")
chat_ctx.append(text="Sure, let me check your order details.", role="assistant")

stream = input_llm.chat(chat_ctx=chat_ctx)
collected_text = ""
async for chunk in stream:
if not chunk.choices:
continue
content = chunk.choices[0].delta.content
if content:
collected_text += content

assert len(collected_text) > 0, "Expected a non-empty response from the LLM chat"


@pytest.mark.parametrize("llm_factory", LLMS)
async def test_basic_fnc_calls(llm_factory: Callable[[], llm.LLM]):
input_llm = llm_factory()
@@ -348,9 +378,7 @@ async def test_tool_choice_options(
print(calls)

call_names = {call.call_info.function_info.name for call in calls}
if tool_choice == "none" and isinstance(input_llm, anthropic.LLM):
assert True
else:
if tool_choice == "none":
assert call_names == expected_calls, (
f"Test '{description}' failed: Expected calls {expected_calls}, but got {call_names}"
)