
feat: support for custom toolchoice in LLMs #1102

Merged: 45 commits, Nov 26, 2024
Commits (45). The diff below shows changes from 38 commits.
e0e83e2  support for custom tooluse in anthropic (jayeshp19, Nov 17, 2024)
7379ca4  support for custom tooluse in anthropic (jayeshp19, Nov 17, 2024)
86bee50  support for custom tooluse in anthropic (jayeshp19, Nov 17, 2024)
4cd8a40  Merge branch 'main' of https://github.com/livekit/agents into jp/anth… (jayeshp19, Nov 22, 2024)
b4d5292  openai parallel tool call impl (jayeshp19, Nov 22, 2024)
9dc1e07  wip (jayeshp19, Nov 22, 2024)
344d040  tool choice wip (jayeshp19, Nov 22, 2024)
d441874  wip (jayeshp19, Nov 22, 2024)
e2ff9da  wip (jayeshp19, Nov 22, 2024)
b5611c5  tool choice wip (jayeshp19, Nov 22, 2024)
efbc8ed  tool choice wip (jayeshp19, Nov 22, 2024)
6e23f77  wip (jayeshp19, Nov 22, 2024)
9ae405d  added test cases (jayeshp19, Nov 22, 2024)
ce9ca43  added test cases (jayeshp19, Nov 22, 2024)
62070bd  assistant llm tool choice param (jayeshp19, Nov 22, 2024)
a1c0dbe  minor (jayeshp19, Nov 22, 2024)
45da685  Merge branch 'main' of https://github.com/livekit/agents into jp/anth… (jayeshp19, Nov 25, 2024)
439e2ff  updates (jayeshp19, Nov 25, 2024)
06e6555  updates (jayeshp19, Nov 25, 2024)
7d64c10  updates (jayeshp19, Nov 25, 2024)
69d6a68  updates (jayeshp19, Nov 25, 2024)
6e4c2f3  updates (jayeshp19, Nov 25, 2024)
9cfb364  updates (jayeshp19, Nov 25, 2024)
31006c2  updates (jayeshp19, Nov 25, 2024)
4845c59  updates (jayeshp19, Nov 25, 2024)
47ac432  updates (jayeshp19, Nov 25, 2024)
fcc1e97  updates (jayeshp19, Nov 25, 2024)
cf98684  updates (jayeshp19, Nov 25, 2024)
dce1ad5  typecheck (jayeshp19, Nov 25, 2024)
5a924ff  typecheck (jayeshp19, Nov 25, 2024)
2dbcdbe  typecheck (jayeshp19, Nov 25, 2024)
01b573a  typecheck (jayeshp19, Nov 25, 2024)
ec8ff13  updates (jayeshp19, Nov 25, 2024)
1c472d8  typecheck (jayeshp19, Nov 25, 2024)
475330c  updates (jayeshp19, Nov 25, 2024)
d7618bc  minor (jayeshp19, Nov 25, 2024)
e726c56  minor (jayeshp19, Nov 25, 2024)
d648ec7  minor (jayeshp19, Nov 25, 2024)
a89616b  Update livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anth… (jayeshp19, Nov 26, 2024)
b051534  Merge branch 'main' of https://github.com/livekit/agents into jp/anth… (jayeshp19, Nov 26, 2024)
9edfe29  updates (jayeshp19, Nov 26, 2024)
d6d3cff  updates (jayeshp19, Nov 26, 2024)
b2e984b  updates (jayeshp19, Nov 26, 2024)
0aba89e  updates (jayeshp19, Nov 26, 2024)
7490a95  updates (jayeshp19, Nov 26, 2024)
7 changes: 7 additions & 0 deletions .changeset/real-phones-cheat.md
@@ -0,0 +1,7 @@
---
"livekit-plugins-anthropic": patch
"livekit-plugins-openai": patch
"livekit-agents": patch
Review comment (Member): add llama-index here

---

support for custom tool use in LLMs
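
For context, here is a minimal usage sketch of the feature this changeset describes; it is not part of the diff. It assumes the Anthropic plugin from this PR, an ANTHROPIC_API_KEY in the environment, and that ChatContext.append is chainable; the get_weather function name is hypothetical and would normally be registered in a FunctionContext passed via fnc_ctx.

# Minimal usage sketch (assumption-laden, not from the diff): pass the new
# tool_choice argument either as a plain mode string or as a FunctionToolChoice
# dict that forces a specific function by name.
from livekit.agents.llm import ChatContext, FunctionToolChoice
from livekit.plugins import anthropic

llm_instance = anthropic.LLM()  # tool_choice defaults to "auto"
chat_ctx = ChatContext().append(role="user", text="What's the weather in Tokyo?")

# Hypothetical example: force the model to call a function named "get_weather";
# in practice a matching FunctionContext would be supplied via fnc_ctx.
forced: FunctionToolChoice = {"type": "function", "name": "get_weather"}
stream = llm_instance.chat(chat_ctx=chat_ctx, tool_choice=forced)

# The string modes are also accepted: "auto", "required", or "none".
stream = llm_instance.chat(chat_ctx=chat_ctx, tool_choice="required")
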
2 changes: 2 additions & 0 deletions livekit-agents/livekit/agents/llm/__init__.py
@@ -23,6 +23,7 @@
Choice,
ChoiceDelta,
CompletionUsage,
FunctionToolChoice,
LLMCapabilities,
LLMStream,
)
@@ -52,4 +53,5 @@
"LLMCapabilities",
"FallbackAdapter",
"AvailabilityChangedEvent",
"FunctionToolChoice",
]
11 changes: 9 additions & 2 deletions livekit-agents/livekit/agents/llm/fallback_adapter.py
@@ -4,15 +4,15 @@
import dataclasses
import time
from dataclasses import dataclass
from typing import AsyncIterable, Literal
from typing import AsyncIterable, Literal, Union

from livekit.agents._exceptions import APIConnectionError, APIError

from ..log import logger
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from .chat_context import ChatContext
from .function_context import FunctionContext
from .llm import LLM, ChatChunk, LLMStream
from .llm import LLM, ChatChunk, FunctionToolChoice, LLMStream

DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
max_retry=0, timeout=DEFAULT_API_CONNECT_OPTIONS.timeout
@@ -66,6 +66,8 @@ def chat(
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[FunctionToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream":
return FallbackLLMStream(
llm=self,
@@ -75,6 +77,7 @@ def chat(
temperature=temperature,
n=n,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)


@@ -89,6 +92,8 @@ def __init__(
temperature: float | None,
n: int | None,
parallel_tool_calls: bool | None,
tool_choice: Union[FunctionToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> None:
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
@@ -97,6 +102,7 @@ def __init__(
self._temperature = temperature
self._n = n
self._parallel_tool_calls = parallel_tool_calls
self._tool_choice = tool_choice

async def _try_generate(
self, *, llm: LLM, recovering: bool = False
@@ -108,6 +114,7 @@ async def _try_generate(
temperature=self._temperature,
n=self._n,
parallel_tool_calls=self._parallel_tool_calls,
tool_choice=self._tool_choice,
conn_options=dataclasses.replace(
self._conn_options,
max_retry=self._fallback_adapter._max_retry_per_llm,
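
For context, a sketch of how tool_choice flows through the fallback path added above: the adapter stores the value and forwards it unchanged to whichever underlying LLM it ends up calling. This assumes FallbackAdapter accepts a list of LLM instances positionally and that both plugin packages are installed with their API keys configured; it is illustrative, not a documented example.

# Sketch of tool_choice passing through FallbackAdapter (assumptions above).
from livekit.agents.llm import ChatContext, FallbackAdapter
from livekit.plugins import anthropic, openai

fallback_llm = FallbackAdapter([openai.LLM(), anthropic.LLM()])
chat_ctx = ChatContext().append(role="user", text="Book a table for two.")

# "required" is stored on the FallbackLLMStream and handed unchanged to the
# active LLM, which maps it to its provider-specific payload
# (for Anthropic, {"type": "any"}).
stream = fallback_llm.chat(chat_ctx=chat_ctx, tool_choice="required")
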
18 changes: 17 additions & 1 deletion livekit-agents/livekit/agents/llm/llm.py
@@ -5,7 +5,16 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any, AsyncIterable, AsyncIterator, Generic, Literal, TypeVar, Union
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Generic,
Literal,
TypedDict,
TypeVar,
Union,
)

from livekit import rtc
from livekit.agents._exceptions import APIConnectionError, APIError
@@ -51,6 +60,11 @@ class ChatChunk:
usage: CompletionUsage | None = None


class FunctionToolChoice(TypedDict):
type: Literal["function"]
name: str

Review comment (Member), suggested change: class FunctionToolChoice(TypedDict): → class ToolChoice(TypedDict):


TEvent = TypeVar("TEvent")


@@ -78,6 +92,8 @@ def chat(
temperature: float | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[FunctionToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream": ...

@property
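
As a standalone illustration of the union type introduced above, tool_choice is either the FunctionToolChoice TypedDict, which forces one named function, or one of the literal modes "auto", "required", "none". Only FunctionToolChoice mirrors the diff; the ToolChoiceParam alias and describe_tool_choice helper below are hypothetical.

# Self-contained sketch of the tool_choice union.
from typing import Literal, TypedDict, Union


class FunctionToolChoice(TypedDict):
    type: Literal["function"]
    name: str


ToolChoiceParam = Union[FunctionToolChoice, Literal["auto", "required", "none"]]


def describe_tool_choice(tool_choice: ToolChoiceParam) -> str:
    # The dict form names one function to force; the strings select a mode.
    if isinstance(tool_choice, dict):
        return f"force a call to the '{tool_choice['name']}' function"
    return f"tool use mode: {tool_choice}"


print(describe_tool_choice({"type": "function", "name": "get_weather"}))  # forced call
print(describe_tool_choice("auto"))  # model decides
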
livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py
@@ -19,7 +19,16 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Awaitable, List, Tuple, get_args, get_origin
from typing import (
Any,
Awaitable,
List,
Literal,
Tuple,
Union,
get_args,
get_origin,
)

import httpx
from livekit import rtc
@@ -30,6 +39,7 @@
llm,
utils,
)
from livekit.agents.llm import FunctionToolChoice
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

import anthropic
@@ -45,6 +55,8 @@ class LLMOptions:
model: str | ChatModels
user: str | None
temperature: float | None
parallel_tool_calls: bool | None
tool_choice: Union[FunctionToolChoice, Literal["auto", "required", "none"]] | None


class LLM(llm.LLM):
@@ -57,6 +69,10 @@ def __init__(
user: str | None = None,
client: anthropic.AsyncClient | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[
FunctionToolChoice, Literal["auto", "required", "none"]
] = "auto",
) -> None:
"""
Create a new instance of Anthropic LLM.
@@ -71,7 +87,13 @@ def __init__(
if api_key is None:
raise ValueError("Anthropic API key is required")

self._opts = LLMOptions(model=model, user=user, temperature=temperature)
self._opts = LLMOptions(
model=model,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
self._client = client or anthropic.AsyncClient(
api_key=api_key,
base_url=base_url,
@@ -95,9 +117,15 @@ def chat(
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[FunctionToolChoice, Literal["auto", "required", "none"]]
| None = None,
) -> "LLMStream":
if temperature is None:
temperature = self._opts.temperature
if parallel_tool_calls is None:
parallel_tool_calls = self._opts.parallel_tool_calls
if tool_choice is None:
tool_choice = self._opts.tool_choice

opts: dict[str, Any] = dict()
if fnc_ctx and len(fnc_ctx.ai_functions) > 0:
@@ -106,9 +134,20 @@
fncs_desc.append(_build_function_description(fnc))

opts["tools"] = fncs_desc

if fnc_ctx and parallel_tool_calls is not None:
opts["parallel_tool_calls"] = parallel_tool_calls
if tool_choice is not None:
anthropic_tool_choice: dict[str, Any] = {"type": "auto"}
if isinstance(tool_choice, dict):
if tool_choice["type"] == "function":
anthropic_tool_choice = {
"type": "tool",
"name": tool_choice["name"],
}
elif isinstance(tool_choice, str):
if tool_choice == "required":
anthropic_tool_choice = {"type": "any"}
if parallel_tool_calls is not None and parallel_tool_calls is False:
anthropic_tool_choice["disable_parallel_tool_use"] = True
opts["tool_choice"] = anthropic_tool_choice

latest_system_message = _latest_system_message(chat_ctx)
anthropic_ctx = _build_anthropic_context(chat_ctx.messages, id(self))
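
The block above translates the generic tool_choice into Anthropic's payload. The standalone helper below restates that mapping for readability; the function name is hypothetical, since the plugin performs the translation inline in chat().

# Sketch of the mapping performed above: "auto" and "none" stay {"type": "auto"},
# "required" becomes {"type": "any"}, a FunctionToolChoice becomes
# {"type": "tool", "name": ...}, and parallel_tool_calls=False sets
# disable_parallel_tool_use.
from typing import Any, Literal, Optional, TypedDict, Union


class FunctionToolChoice(TypedDict):
    type: Literal["function"]
    name: str


def to_anthropic_tool_choice(
    tool_choice: Union[FunctionToolChoice, Literal["auto", "required", "none"]],
    parallel_tool_calls: Optional[bool],
) -> dict[str, Any]:
    anthropic_tool_choice: dict[str, Any] = {"type": "auto"}
    if isinstance(tool_choice, dict) and tool_choice["type"] == "function":
        anthropic_tool_choice = {"type": "tool", "name": tool_choice["name"]}
    elif tool_choice == "required":
        anthropic_tool_choice = {"type": "any"}
    if parallel_tool_calls is False:
        anthropic_tool_choice["disable_parallel_tool_use"] = True
    return anthropic_tool_choice


assert to_anthropic_tool_choice("required", False) == {
    "type": "any",
    "disable_parallel_tool_use": True,
}
assert to_anthropic_tool_choice({"type": "function", "name": "get_weather"}, None) == {
    "type": "tool",
    "name": "get_weather",
}
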
livekit-plugins/livekit-plugins-openai/livekit/plugins/openai/beta/assistant_llm.py
@@ -18,11 +18,12 @@
import json
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, MutableSet
from typing import Any, Callable, Dict, Literal, MutableSet, Union

import httpx
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm import FunctionToolChoice
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

from openai import AsyncAssistantEventHandler, AsyncClient
@@ -172,6 +173,8 @@ def chat(
temperature: float | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[FunctionToolChoice, Literal["auto", "required", "none"]]
| None = None,
):
if n is not None:
logger.warning("OpenAI Assistants does not support the 'n' parameter")
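
This hunk only adds tool_choice to the Assistants-based chat() signature; the translation to OpenAI's wire format is not part of this diff view. As a hedged illustration, OpenAI's API accepts the string modes directly and wraps a forced function in a nested object, roughly as below; the helper name is hypothetical.

# Illustrative sketch of a generic-to-OpenAI tool_choice translation.
from typing import Any, Literal, TypedDict, Union


class FunctionToolChoice(TypedDict):
    type: Literal["function"]
    name: str


def to_openai_tool_choice(
    tool_choice: Union[FunctionToolChoice, Literal["auto", "required", "none"]],
) -> Union[str, dict[str, Any]]:
    # A named function is wrapped in OpenAI's nested "function" object.
    if isinstance(tool_choice, dict):
        return {"type": "function", "function": {"name": tool_choice["name"]}}
    # "auto", "required", and "none" pass through as plain strings.
    return tool_choice


assert to_openai_tool_choice("auto") == "auto"
assert to_openai_tool_choice({"type": "function", "name": "get_weather"}) == {
    "type": "function",
    "function": {"name": "get_weather"},
}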