Skip to content

Commit

Permalink
fix: ensure llm.FallbackAdapter executes function calls correctly (#1409
Browse files Browse the repository at this point in the history
)
  • Loading branch information
davidzhao authored Jan 24, 2025
1 parent 695f7b5 commit 57ae9a2
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 8 deletions.
5 changes: 5 additions & 0 deletions .changeset/modern-seas-push.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"livekit-agents": patch
---

fix: ensure llm.FallbackAdapter executes function calls
13 changes: 11 additions & 2 deletions examples/voice-pipeline-agent/fallback_adapter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from datetime import datetime

from dotenv import load_dotenv
from livekit.agents import (
Expand All @@ -12,7 +13,7 @@
tts,
)
from livekit.agents.pipeline import VoicePipelineAgent
from livekit.plugins import cartesia, deepgram, elevenlabs, openai, silero
from livekit.plugins import cartesia, deepgram, openai, playai, silero

load_dotenv()
logger = logging.getLogger("fallback-adapter-example")
Expand All @@ -31,6 +32,13 @@ async def entrypoint(ctx: JobContext):
),
)

fnc_ctx = llm.FunctionContext()

@fnc_ctx.ai_callable()
def get_time():
"""called to retrieve the current local time"""
return datetime.now().strftime("%H:%M:%S")

await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)

# wait for the first participant to connect
Expand Down Expand Up @@ -60,7 +68,7 @@ async def entrypoint(ctx: JobContext):
fallback_tts = tts.FallbackAdapter(
[
cartesia.TTS(),
elevenlabs.TTS(),
playai.TTS(),
]
)

Expand All @@ -70,6 +78,7 @@ async def entrypoint(ctx: JobContext):
llm=fallback_llm,
tts=fallback_tts,
chat_ctx=initial_ctx,
fnc_ctx=fnc_ctx,
)

agent.start(ctx.room, participant)
Expand Down
36 changes: 30 additions & 6 deletions livekit-agents/livekit/agents/llm/fallback_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import dataclasses
import time
from dataclasses import dataclass
from typing import AsyncIterable, Literal, Union
from typing import AsyncIterable, Literal, Optional, Union

from livekit.agents._exceptions import APIConnectionError, APIError

from ..log import logger
from ..types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions
from .chat_context import ChatContext
from .function_context import FunctionContext
from .function_context import CalledFunction, FunctionCallInfo, FunctionContext
from .llm import LLM, ChatChunk, LLMStream, ToolChoice

DEFAULT_FALLBACK_API_CONNECT_OPTIONS = APIConnectOptions(
Expand Down Expand Up @@ -104,6 +104,31 @@ def __init__(
self._parallel_tool_calls = parallel_tool_calls
self._tool_choice = tool_choice

self._current_stream: Optional[LLMStream] = None

@property
def function_calls(self) -> list[FunctionCallInfo]:
if self._current_stream is None:
return []
return self._current_stream.function_calls

@property
def chat_ctx(self) -> ChatContext:
if self._current_stream is None:
return self._chat_ctx
return self._current_stream.chat_ctx

@property
def fnc_ctx(self) -> FunctionContext | None:
if self._current_stream is None:
return self._fnc_ctx
return self._current_stream.fnc_ctx

def execute_functions(self) -> list[CalledFunction]:
if self._current_stream is None:
return []
return self._current_stream.execute_functions()

async def _try_generate(
self, *, llm: LLM, recovering: bool = False
) -> AsyncIterable[ChatChunk]:
Expand All @@ -122,6 +147,7 @@ async def _try_generate(
retry_interval=self._fallback_adapter._retry_interval,
),
) as stream:
self._current_stream = stream
async for chunk in stream:
yield chunk

Expand Down Expand Up @@ -196,11 +222,9 @@ async def _run(self) -> None:
if llm_status.available or all_failed:
chunk_sent = False
try:
async for synthesized_audio in self._try_generate(
llm=llm, recovering=False
):
async for result in self._try_generate(llm=llm, recovering=False):
chunk_sent = True
self._event_ch.send_nowait(synthesized_audio)
self._event_ch.send_nowait(result)

return
except Exception: # exceptions already logged inside _try_synthesize
Expand Down

0 comments on commit 57ae9a2

Please sign in to comment.