Skip to content

Commit

Permalink
fix: Do not set fnc_ctx to None for anthropic at max depth (#1441)
Browse files Browse the repository at this point in the history
  • Loading branch information
jayeshp19 authored Feb 6, 2025
1 parent 5f977fa commit 0865638
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 14 deletions.
8 changes: 8 additions & 0 deletions .changeset/tasty-brooms-exercise.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
"livekit-plugins-anthropic": patch
"livekit-plugins-google": patch
"livekit-plugins-openai": patch
"livekit-agents": patch
---

Added an additional field to the LLM capabilities class indicating whether a model provider supports function-call history in the chat context without the corresponding function definitions being present.
13 changes: 11 additions & 2 deletions livekit-agents/livekit/agents/llm/fallback_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from .chat_context import ChatContext
from .function_context import CalledFunction, FunctionCallInfo, FunctionContext
from .llm import LLM, ChatChunk, LLMStream, ToolChoice
from .llm import LLM, ChatChunk, LLMCapabilities, LLMStream, ToolChoice

DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
Expand Down Expand Up @@ -45,7 +45,16 @@ def __init__(
if len(llm) < 1:
raise ValueError("at least one LLM instance must be provided.")

super().__init__()
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=all(
t.capabilities.supports_choices_on_int for t in llm
),
requires_persistent_functions=all(
t.capabilities.requires_persistent_functions for t in llm
),
)
)

self._llm_instances = llm
self._attempt_timeout = attempt_timeout
Expand Down
9 changes: 7 additions & 2 deletions livekit-agents/livekit/agents/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class Choice:
@dataclass
class LLMCapabilities:
    # Feature flags describing what a concrete LLM provider supports; adapters
    # (e.g. FallbackAdapter) aggregate these across wrapped instances.
    supports_choices_on_int: bool = True
    """Whether the LLM accepts integer enum choices as function-call arguments."""
    requires_persistent_functions: bool = False
    """Whether the LLM requires function definitions to remain in the request
    whenever previous function calls exist in the chat context (e.g. Anthropic
    rejects a context containing tool-use history if the tools are omitted)."""


@dataclass
Expand All @@ -73,9 +76,11 @@ class LLM(
rtc.EventEmitter[Union[Literal["metrics_collected"], TEvent]],
Generic[TEvent],
):
def __init__(self) -> None:
def __init__(self, *, capabilities: LLMCapabilities | None = None) -> None:
super().__init__()
self._capabilities = LLMCapabilities()
if capabilities is None:
capabilities = LLMCapabilities()
self._capabilities = capabilities
self._label = f"{type(self).__module__}.{type(self).__name__}"

@property
Expand Down
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/pipeline/pipeline_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,7 @@ async def _execute_function_calls() -> None:
fnc_ctx
and new_speech_handle.fnc_nested_depth
>= self._opts.max_nested_fnc_calls
and not self._llm.capabilities.requires_persistent_functions
):
if len(fnc_ctx.ai_functions) > 1:
logger.info(
Expand All @@ -991,6 +992,7 @@ async def _execute_function_calls() -> None:
},
)
fnc_ctx = None

answer_llm_stream = self._llm.chat(
chat_ctx=chat_ctx,
fnc_ctx=fnc_ctx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
llm,
utils,
)
from livekit.agents.llm import ToolChoice
from livekit.agents.llm import LLMCapabilities, ToolChoice
from livekit.agents.llm.function_context import (
_create_ai_function_info,
_is_optional_type,
Expand Down Expand Up @@ -82,7 +82,13 @@ def __init__(
``api_key`` must be set to your Anthropic API key, either using the argument or by setting
the ``ANTHROPIC_API_KEY`` environmental variable.
"""
super().__init__()

super().__init__(
capabilities=LLMCapabilities(
requires_persistent_functions=True,
supports_choices_on_int=True,
)
)

# throw an error on our end
api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
llm,
utils,
)
from livekit.agents.llm import ToolChoice, _create_ai_function_info
from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

from google import genai
Expand Down Expand Up @@ -99,8 +99,12 @@ def __init__(
frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
"""
super().__init__()
self._capabilities = llm.LLMCapabilities(supports_choices_on_int=False)
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=False,
requires_persistent_functions=False,
)
)
self._project_id = project or os.environ.get("GOOGLE_CLOUD_PROJECT", None)
self._location = location or os.environ.get(
"GOOGLE_CLOUD_LOCATION", "us-central1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import httpx
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm import ToolChoice
from livekit.agents.llm import LLMCapabilities, ToolChoice
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

from openai import AsyncAssistantEventHandler, AsyncClient
Expand Down Expand Up @@ -99,7 +99,12 @@ def __init__(
base_url: str | None = None,
on_file_uploaded: OnFileUploaded | None = None,
) -> None:
super().__init__()
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=True,
requires_persistent_functions=False,
)
)

test_ctx = llm.ChatContext()
if not hasattr(test_ctx, "_metadata"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
APITimeoutError,
llm,
)
from livekit.agents.llm import ToolChoice, _create_ai_function_info
from livekit.agents.llm import (
LLMCapabilities,
ToolChoice,
_create_ai_function_info,
)
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

import openai
Expand Down Expand Up @@ -85,8 +89,12 @@ def __init__(
``api_key`` must be set to your OpenAI API key, either using the argument or by setting the
``OPENAI_API_KEY`` environmental variable.
"""
super().__init__()
self._capabilities = llm.LLMCapabilities(supports_choices_on_int=True)
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=True,
requires_persistent_functions=False,
)
)

self._opts = LLMOptions(
model=model,
Expand Down

0 comments on commit 0865638

Please sign in to comment.