feat: support for custom tool_choice in LLMs #1102

Merged: 45 commits, Nov 26, 2024
Changes shown are from 18 of the 45 commits.

Commits (45)
e0e83e2
support for custom tooluse in anthropic
jayeshp19 Nov 17, 2024
7379ca4
support for custom tooluse in anthropic
jayeshp19 Nov 17, 2024
86bee50
support for custom tooluse in anthropic
jayeshp19 Nov 17, 2024
4cd8a40
Merge branch 'main' of https://github.com/livekit/agents into jp/anth…
jayeshp19 Nov 22, 2024
b4d5292
openai parallel tool call impl
jayeshp19 Nov 22, 2024
9dc1e07
wip
jayeshp19 Nov 22, 2024
344d040
tool choice wip
jayeshp19 Nov 22, 2024
d441874
wip
jayeshp19 Nov 22, 2024
e2ff9da
wip
jayeshp19 Nov 22, 2024
b5611c5
tool choice wip
jayeshp19 Nov 22, 2024
efbc8ed
tool choice wip
jayeshp19 Nov 22, 2024
6e23f77
wip
jayeshp19 Nov 22, 2024
9ae405d
added test cases
jayeshp19 Nov 22, 2024
ce9ca43
added test cases
jayeshp19 Nov 22, 2024
62070bd
assistant llm tool choice param
jayeshp19 Nov 22, 2024
a1c0dbe
minor
jayeshp19 Nov 22, 2024
45da685
Merge branch 'main' of https://github.com/livekit/agents into jp/anth…
jayeshp19 Nov 25, 2024
439e2ff
updates
jayeshp19 Nov 25, 2024
06e6555
updates
jayeshp19 Nov 25, 2024
7d64c10
updates
jayeshp19 Nov 25, 2024
69d6a68
updates
jayeshp19 Nov 25, 2024
6e4c2f3
updates
jayeshp19 Nov 25, 2024
9cfb364
updates
jayeshp19 Nov 25, 2024
31006c2
updates
jayeshp19 Nov 25, 2024
4845c59
updates
jayeshp19 Nov 25, 2024
47ac432
updates
jayeshp19 Nov 25, 2024
fcc1e97
updates
jayeshp19 Nov 25, 2024
cf98684
updates
jayeshp19 Nov 25, 2024
dce1ad5
typecheck
jayeshp19 Nov 25, 2024
5a924ff
typecheck
jayeshp19 Nov 25, 2024
2dbcdbe
typecheck
jayeshp19 Nov 25, 2024
01b573a
typecheck
jayeshp19 Nov 25, 2024
ec8ff13
updates
jayeshp19 Nov 25, 2024
1c472d8
typecheck
jayeshp19 Nov 25, 2024
475330c
updates
jayeshp19 Nov 25, 2024
d7618bc
minor
jayeshp19 Nov 25, 2024
e726c56
minor
jayeshp19 Nov 25, 2024
d648ec7
minor
jayeshp19 Nov 25, 2024
a89616b
Update livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anth…
jayeshp19 Nov 26, 2024
b051534
Merge branch 'main' of https://github.com/livekit/agents into jp/anth…
jayeshp19 Nov 26, 2024
9edfe29
updates
jayeshp19 Nov 26, 2024
d6d3cff
updates
jayeshp19 Nov 26, 2024
b2e984b
updates
jayeshp19 Nov 26, 2024
0aba89e
updates
jayeshp19 Nov 26, 2024
7490a95
updates
jayeshp19 Nov 26, 2024
Files changed
5 changes: 5 additions & 0 deletions .changeset/real-phones-cheat.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-anthropic": patch
---

support for custom tool use in anthropic
10 changes: 9 additions & 1 deletion livekit-agents/livekit/agents/llm/fallback_adapter.py
@@ -4,7 +4,7 @@
import dataclasses
import time
from dataclasses import dataclass
from typing import AsyncIterable, Literal
from typing import AsyncIterable, Literal, TypedDict, Union

from livekit.agents._exceptions import APIConnectionError, APIError

@@ -31,6 +31,11 @@ class AvailabilityChangedEvent:
available: bool


class ToolChoice(TypedDict, total=False):
type: Literal["auto", "any", "tool", "none", "required"]
name: str


class FallbackAdapter(
LLM[Literal["llm_availability_changed"]],
):
@@ -66,6 +71,9 @@ def chat(
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[
ToolChoice, None, Literal["auto", "any", "none", "required"]
] = None,
) -> "LLMStream":
return FallbackLLMStream(
llm=self,
19 changes: 18 additions & 1 deletion livekit-agents/livekit/agents/llm/llm.py
@@ -5,7 +5,16 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from types import TracebackType
from typing import Any, AsyncIterable, AsyncIterator, Generic, Literal, TypeVar, Union
from typing import (
Any,
AsyncIterable,
AsyncIterator,
Generic,
Literal,
TypedDict,
TypeVar,
Union,
)

from livekit import rtc
from livekit.agents._exceptions import APIConnectionError, APIError
@@ -51,6 +60,11 @@ class ChatChunk:
usage: CompletionUsage | None = None


class ToolChoice(TypedDict, total=False):
type: Literal["auto", "any", "tool", "none", "required"]
name: str


TEvent = TypeVar("TEvent")


@@ -78,6 +92,9 @@ def chat(
temperature: float | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[
ToolChoice, None, Literal["auto", "any", "none", "required"]
] = None,
) -> "LLMStream": ...

@property
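The base-class signature above is what each plugin now implements. As a rough usage sketch (the my_llm instance, chat_ctx, fnc_ctx, and the "get_weather" tool name are illustrative placeholders, not part of this PR), tool_choice can be passed either as a plain string or as a ToolChoice dict:

```python
# Illustrative sketch only: `my_llm`, `chat_ctx`, `fnc_ctx`, and "get_weather"
# are placeholders standing in for a real LLM instance, ChatContext,
# FunctionContext, and a registered AI function.

# Let the model decide whether to call a tool (the usual default):
stream = my_llm.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, tool_choice="auto")

# Require the model to call at least one tool; "any" and "required" are
# treated the same way by the plugins in this PR:
stream = my_llm.chat(chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, tool_choice="required")

# Force a specific tool by name using the ToolChoice TypedDict:
stream = my_llm.chat(
    chat_ctx=chat_ctx,
    fnc_ctx=fnc_ctx,
    tool_choice={"type": "tool", "name": "get_weather"},
)
```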
livekit-plugins/livekit-plugins-anthropic/livekit/plugins/anthropic/llm.py
@@ -19,7 +19,17 @@
import json
import os
from dataclasses import dataclass
from typing import Any, Awaitable, List, Tuple, get_args, get_origin
from typing import (
Any,
Awaitable,
List,
Literal,
Tuple,
TypedDict,
Union,
get_args,
get_origin,
)

import httpx
from livekit import rtc
@@ -40,11 +50,18 @@
)


class ToolChoice(TypedDict, total=False):
type: Literal["auto", "any", "tool", "none", "required"]

Member:
What is the difference between tool and required?
What is any used for?

Collaborator (Author):
required (OpenAI-specific param): the model must call one or more tools.
any (Anthropic-specific param): the model must call one or more tools.
tool (Anthropic-specific param): the model must call a specific tool.

Member:
Can we merge any and required?

Collaborator (Author):
Yes, we can, but I think we may need to document this; otherwise it might confuse users.

Collaborator (Author):
I've updated the code accordingly.

name: str # Optional: only used when type is "tool"


@dataclass
class LLMOptions:
model: str | ChatModels
user: str | None
temperature: float | None
parallel_tool_calls: bool | None
tool_choice: Union[ToolChoice, None, Literal["auto", "any", "none", "required"]]


class LLM(llm.LLM):
@@ -57,6 +74,10 @@ def __init__(
user: str | None = None,
client: anthropic.AsyncClient | None = None,
temperature: float | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[
ToolChoice, None, Literal["auto", "any", "none", "required"]
] = None,
) -> None:
"""
Create a new instance of Anthropic LLM.
@@ -71,7 +92,13 @@
if api_key is None:
raise ValueError("Anthropic API key is required")

self._opts = LLMOptions(model=model, user=user, temperature=temperature)
self._opts = LLMOptions(
model=model,
user=user,
temperature=temperature,
parallel_tool_calls=parallel_tool_calls,
tool_choice=tool_choice,
)
self._client = client or anthropic.AsyncClient(
api_key=api_key,
base_url=base_url,
@@ -95,9 +122,16 @@ def chat(
temperature: float | None = None,
n: int | None = 1,
parallel_tool_calls: bool | None = None,
tool_choice: Union[
ToolChoice, None, Literal["auto", "any", "none", "required"]
] = None,
) -> "LLMStream":
if temperature is None:
temperature = self._opts.temperature
if parallel_tool_calls is None:
parallel_tool_calls = self._opts.parallel_tool_calls
if tool_choice is None:
tool_choice = self._opts.tool_choice

opts: dict[str, Any] = dict()
if fnc_ctx and len(fnc_ctx.ai_functions) > 0:
@@ -106,9 +140,23 @@
fncs_desc.append(_build_function_description(fnc))

opts["tools"] = fncs_desc

if fnc_ctx and parallel_tool_calls is not None:
opts["parallel_tool_calls"] = parallel_tool_calls
anthropic_tool_choice: dict[str, Any] = {"type": "auto"}
if isinstance(tool_choice, dict):
if tool_choice["type"] == "tool":
anthropic_tool_choice = {
"type": "tool",
"name": tool_choice["name"],
}
elif tool_choice["type"] in ["any", "required"]:
anthropic_tool_choice = {"type": "any"}
elif isinstance(tool_choice, str):
if tool_choice in ["any", "required"]:
anthropic_tool_choice = {"type": "any"}
elif tool_choice == "none":
anthropic_tool_choice = {"type": "auto"}
if parallel_tool_calls is not None and parallel_tool_calls is False:
anthropic_tool_choice["disable_parallel_tool_use"] = True
opts["tool_choice"] = anthropic_tool_choice

latest_system_message = _latest_system_message(chat_ctx)
anthropic_ctx = _build_anthropic_context(chat_ctx.messages, id(self))
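To make the hunk above easier to follow, the mapping it applies can be condensed into a standalone helper. The function name and the asserts below are illustrative, but each branch mirrors the diff: OpenAI-style "required" is folded into Anthropic's "any" (the compromise settled on in the review thread earlier), "none" falls back to "auto", and disabling parallel tool calls is expressed inside tool_choice via disable_parallel_tool_use rather than as a top-level option.

```python
from __future__ import annotations

from typing import Any, Literal, TypedDict, Union


class ToolChoice(TypedDict, total=False):
    type: Literal["auto", "any", "tool", "none", "required"]
    name: str  # only used when type is "tool"


def to_anthropic_tool_choice(
    tool_choice: Union[ToolChoice, None, Literal["auto", "any", "none", "required"]],
    parallel_tool_calls: bool | None,
) -> dict[str, Any]:
    # Default: let the model decide whether to use a tool.
    choice: dict[str, Any] = {"type": "auto"}
    if isinstance(tool_choice, dict):
        if tool_choice["type"] == "tool":
            # Force one specific tool by name.
            choice = {"type": "tool", "name": tool_choice["name"]}
        elif tool_choice["type"] in ("any", "required"):
            # "required" (OpenAI wording) and "any" (Anthropic wording) both
            # mean "the model must call some tool".
            choice = {"type": "any"}
    elif isinstance(tool_choice, str):
        if tool_choice in ("any", "required"):
            choice = {"type": "any"}
        elif tool_choice == "none":
            # The diff maps "none" back to "auto" for Anthropic.
            choice = {"type": "auto"}
    if parallel_tool_calls is False:
        # Anthropic expresses this inside tool_choice, not as a separate option.
        choice["disable_parallel_tool_use"] = True
    return choice


# Quick sanity checks of the mapping:
assert to_anthropic_tool_choice("required", None) == {"type": "any"}
assert to_anthropic_tool_choice({"type": "tool", "name": "get_weather"}, False) == {
    "type": "tool",
    "name": "get_weather",
    "disable_parallel_tool_use": True,
}
```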
livekit-plugins-openai: OpenAI Assistants LLM (assistant_llm.py)
@@ -18,7 +18,7 @@
import json
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, MutableSet
from typing import Any, Callable, Dict, Literal, MutableSet, TypedDict, Union

import httpx
from livekit import rtc
@@ -53,6 +53,11 @@ class LLMOptions:
model: str | ChatModels


class ToolChoice(TypedDict, total=False):
type: Literal["auto", "any", "tool", "none", "required"]
name: str # Optional: only used when type is "tool"


@dataclass
class AssistantOptions:
"""Options for creating (on-the-fly) or loading an assistant. Only one of create_options or load_options should be set."""
@@ -172,6 +177,9 @@ def chat(
temperature: float | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
tool_choice: Union[
ToolChoice, None, Literal["auto", "any", "none", "required"]
] = None,
):
if n is not None:
logger.warning("OpenAI Assistants does not support the 'n' parameter")