diff --git a/.changeset/tasty-brooms-exercise.md b/.changeset/tasty-brooms-exercise.md
new file mode 100644
index 0000000000..d78a223fb2
--- /dev/null
+++ b/.changeset/tasty-brooms-exercise.md
@@ -0,0 +1,8 @@
+---
+"livekit-plugins-anthropic": patch
+"livekit-plugins-google": patch
+"livekit-plugins-openai": patch
+"livekit-agents": patch
+---
+
+Added a field to the LLM capabilities class to check whether a model provider supports function call history in the chat context without needing function definitions.
diff --git a/livekit-agents/livekit/agents/llm/fallback_adapter.py b/livekit-agents/livekit/agents/llm/fallback_adapter.py
index 4d63a2b9de..33e9c7eb93 100644
--- a/livekit-agents/livekit/agents/llm/fallback_adapter.py
+++ b/livekit-agents/livekit/agents/llm/fallback_adapter.py
@@ -12,7 +12,7 @@
 from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
 from .chat_context import ChatContext
 from .function_context import CalledFunction, FunctionCallInfo, FunctionContext
-from .llm import LLM, ChatChunk, LLMStream, ToolChoice
+from .llm import LLM, ChatChunk, LLMCapabilities, LLMStream, ToolChoice
 
 DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
     max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
@@ -45,7 +45,16 @@ def __init__(
         if len(llm) < 1:
             raise ValueError("at least one LLM instance must be provided.")
 
-        super().__init__()
+        super().__init__(
+            capabilities=LLMCapabilities(
+                supports_choices_on_int=all(
+                    t.capabilities.supports_choices_on_int for t in llm
+                ),
+                requires_persistent_functions=any(
+                    t.capabilities.requires_persistent_functions for t in llm
+                ),
+            )
+        )
 
         self._llm_instances = llm
         self._attempt_timeout = attempt_timeout
diff --git a/livekit-agents/livekit/agents/llm/llm.py b/livekit-agents/livekit/agents/llm/llm.py
index 0275f64c17..07f28a5aaa 100644
--- a/livekit-agents/livekit/agents/llm/llm.py
+++ b/livekit-agents/livekit/agents/llm/llm.py
@@ -50,6 +50,9 @@ class Choice:
 @dataclass
 class LLMCapabilities:
     supports_choices_on_int: bool = True
+    """whether the LLM supports integer enum choices as function arguments"""
+    requires_persistent_functions: bool = False
+    """whether the LLM requires function definitions to be present when previous function calls exist in the chat context"""
 
 
 @dataclass
@@ -73,9 +76,11 @@ class LLM(
     rtc.EventEmitter[Union[Literal["metrics_collected"], TEvent]],
     Generic[TEvent],
 ):
-    def __init__(self) -> None:
+    def __init__(self, *, capabilities: LLMCapabilities | None = None) -> None:
         super().__init__()
-        self._capabilities = LLMCapabilities()
+        if capabilities is None:
+            capabilities = LLMCapabilities()
+        self._capabilities = capabilities
         self._label = f"{type(self).__module__}.{type(self).__name__}"
 
     @property
diff --git a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
index da1beb9086..b50f8c878b 100644
--- a/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
+++ b/livekit-agents/livekit/agents/pipeline/pipeline_agent.py
@@ -981,6 +981,7 @@ async def _execute_function_calls() -> None:
                 fnc_ctx
                 and new_speech_handle.fnc_nested_depth
                 >= self._opts.max_nested_fnc_calls
+                and not self._llm.capabilities.requires_persistent_functions
             ):
                 if len(fnc_ctx.ai_functions) > 1:
                     logger.info(
@@ -991,6 +992,7 @@ async def _execute_function_calls() -> None:
                     },
                 )
                 fnc_ctx = None
+
             answer_llm_stream = self._llm.chat(
                 chat_ctx=chat_ctx,
                 fnc_ctx=fnc_ctx,
diff --git a/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py b/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py
index 3af490211a..cf2b1f194f 100644
--- a/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py
+++ b/livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py
@@ -39,7 +39,7 @@
     llm,
     utils,
 )
-from livekit.agents.llm import ToolChoice
+from livekit.agents.llm import LLMCapabilities, ToolChoice
 from livekit.agents.llm.function_context import (
     _create_ai_function_info,
     _is_optional_type,
@@ -82,7 +82,13 @@ def __init__(
         ``api_key`` must be set to your Anthropic API key, either using the argument or by setting
         the ``ANTHROPIC_API_KEY`` environmental variable.
         """
-        super().__init__()
+
+        super().__init__(
+            capabilities=LLMCapabilities(
+                requires_persistent_functions=True,
+                supports_choices_on_int=True,
+            )
+        )
 
         # throw an error on our end
         api_key = api_key or os.environ.get("ANTHROPIC_API_KEY")
diff --git a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/llm.py b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/llm.py
index 16a0768769..7d09824cb7 100644
--- a/livekit-plugins/livekit-plugins-google/livekit/plugins/google/llm.py
+++ b/livekit-plugins/livekit-plugins-google/livekit/plugins/google/llm.py
@@ -27,7 +27,7 @@
     llm,
     utils,
 )
-from livekit.agents.llm import ToolChoice, _create_ai_function_info
+from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
 from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
 
 from google import genai
@@ -99,8 +99,12 @@ def __init__(
             frequency_penalty (float, optional): Penalizes the model for repeating words. Defaults to None.
             tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
""" - super().__init__() - self._capabilities = llm.LLMCapabilities(supports_choices_on_int=False) + super().__init__( + capabilities=LLMCapabilities( + supports_choices_on_int=False, + requires_persistent_functions=False, + ) + ) self._project_id = project or os.environ.get("GOOGLE_CLOUD_PROJECT", None) self._location = location or os.environ.get( "GOOGLE_CLOUD_LOCATION", "us-central1" diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py index 7df336e890..25917ab6bb 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py @@ -23,7 +23,7 @@ import httpx from livekit import rtc from livekit.agents import llm, utils -from livekit.agents.llm import ToolChoice +from livekit.agents.llm import LLMCapabilities, ToolChoice from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions from openai import AsyncAssistantEventHandler, AsyncClient @@ -99,7 +99,12 @@ def __init__( base_url: str | None = None, on_file_uploaded: OnFileUploaded | None = None, ) -> None: - super().__init__() + super().__init__( + capabilities=LLMCapabilities( + supports_choices_on_int=True, + requires_persistent_functions=False, + ) + ) test_ctx = llm.ChatContext() if not hasattr(test_ctx, "_metadata"): diff --git a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py index fc4258c01e..716cce42a7 100644 --- a/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py +++ b/livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/llm.py @@ -29,7 +29,11 @@ APITimeoutError, llm, ) -from livekit.agents.llm import ToolChoice, _create_ai_function_info +from livekit.agents.llm import ( + LLMCapabilities, + ToolChoice, + _create_ai_function_info, +) from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions import openai @@ -85,8 +89,12 @@ def __init__( ``api_key`` must be set to your OpenAI API key, either using the argument or by setting the ``OPENAI_API_KEY`` environmental variable. """ - super().__init__() - self._capabilities = llm.LLMCapabilities(supports_choices_on_int=True) + super().__init__( + capabilities=LLMCapabilities( + supports_choices_on_int=True, + requires_persistent_functions=False, + ) + ) self._opts = LLMOptions( model=model,