Commit
gemini improvements: exception handling, transcription (#1398)
jayeshp19 authored Jan 21, 2025
1 parent 57c59f2 commit b4b593d
Showing 5 changed files with 69 additions and 33 deletions.
5 changes: 5 additions & 0 deletions .changeset/purple-years-deny.md
@@ -0,0 +1,5 @@
+---
+"livekit-plugins-google": patch
+---
+
+gemini improvements: exception handling, transcription & ensure contents.parts is non-empty in gemini context
@@ -90,9 +90,9 @@ def _build_gemini_ctx(
     chat_ctx: llm.ChatContext, cache_key: Any
 ) -> tuple[list[types.Content], Optional[types.Content]]:
     turns: list[types.Content] = []
-    current_content: Optional[types.Content] = None
     system_instruction: Optional[types.Content] = None
     current_role: Optional[str] = None
+    parts: list[types.Part] = []
 
     for msg in chat_ctx.messages:
         if msg.role == "system":
@@ -107,18 +107,16 @@ def _build_gemini_ctx(
         else:
             role = "user"
 
-        # Start new turn if role changes or if none is set
-        if current_content is None or current_role != role:
-            current_content = types.Content(role=role, parts=[])
-            turns.append(current_content)
+        # If role changed, finalize previous parts into a turn
+        if role != current_role:
+            if current_role is not None and parts:
+                turns.append(types.Content(role=current_role, parts=parts))
             current_role = role
-
-        if current_content.parts is None:
-            current_content.parts = []
+            parts = []
 
         if msg.tool_calls:
             for fnc in msg.tool_calls:
-                current_content.parts.append(
+                parts.append(
                     types.Part(
                         function_call=types.FunctionCall(
                             id=fnc.tool_call_id,
@@ -131,7 +129,7 @@ def _build_gemini_ctx(
         if msg.role == "tool":
            if msg.content:
                 if isinstance(msg.content, dict):
-                    current_content.parts.append(
+                    parts.append(
                         types.Part(
                             function_response=types.FunctionResponse(
                                 id=msg.tool_call_id,
@@ -141,7 +139,7 @@ def _build_gemini_ctx(
                         )
                     )
                 elif isinstance(msg.content, str):
-                    current_content.parts.append(
+                    parts.append(
                         types.Part(
                             function_response=types.FunctionResponse(
                                 id=msg.tool_call_id,
@@ -153,19 +151,19 @@ def _build_gemini_ctx(
         else:
             if msg.content:
                 if isinstance(msg.content, str):
-                    current_content.parts.append(types.Part(text=msg.content))
+                    parts.append(types.Part(text=msg.content))
                 elif isinstance(msg.content, dict):
-                    current_content.parts.append(
-                        types.Part(text=json.dumps(msg.content))
-                    )
+                    parts.append(types.Part(text=json.dumps(msg.content)))
                 elif isinstance(msg.content, list):
                     for item in msg.content:
                         if isinstance(item, str):
-                            current_content.parts.append(types.Part(text=item))
+                            parts.append(types.Part(text=item))
                         elif isinstance(item, llm.ChatImage):
-                            current_content.parts.append(
-                                _build_gemini_image_part(item, cache_key)
-                            )
+                            parts.append(_build_gemini_image_part(item, cache_key))
 
+    # Finalize last role's parts if any remain
+    if current_role is not None and parts:
+        turns.append(types.Content(role=current_role, parts=parts))
+
     return turns, system_instruction

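The net effect of the `_build_gemini_ctx` rewrite is that parts are buffered per role and a turn is appended only once it holds at least one part, which is what the changeset's "ensure contents.parts is non-empty" refers to. Below is a standalone sketch of that grouping logic, using plain (role, text) tuples in place of the google-genai types so it runs without the SDK; `group_turns` is an illustrative name, not part of the plugin.

```python
from typing import Optional


def group_turns(messages: list[tuple[str, str]]) -> list[tuple[str, list[str]]]:
    turns: list[tuple[str, list[str]]] = []
    current_role: Optional[str] = None
    parts: list[str] = []

    for role, text in messages:
        # Role changed: finalize the previous role's buffered parts
        if role != current_role:
            if current_role is not None and parts:
                turns.append((current_role, parts))
            current_role = role
            parts = []
        if text:  # only non-empty content becomes a part
            parts.append(text)

    # Finalize the last role's parts if any remain
    if current_role is not None and parts:
        turns.append((current_role, parts))
    return turns


# Consecutive same-role messages merge into one turn, and a turn that
# produced no parts is never appended:
print(group_turns([("user", "hi"), ("user", "there"), ("model", ""), ("user", "ok")]))
# [('user', ['hi', 'there']), ('user', ['ok'])]
```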
@@ -379,19 +379,20 @@ def server_vad_enabled(self) -> bool:
         return True
 
     def _on_input_speech_done(self, content: TranscriptionContent) -> None:
-        self.emit(
-            "input_speech_transcription_completed",
-            InputTranscription(
-                item_id=content.response_id,
-                transcript=content.text,
-            ),
-        )
+        if content.response_id and content.text:
+            self.emit(
+                "input_speech_transcription_completed",
+                InputTranscription(
+                    item_id=content.response_id,
+                    transcript=content.text,
+                ),
+            )
 
         # self._chat_ctx.append(text=content.text, role="user")
         # TODO: implement sync mechanism to make sure the transcribed user speech is inside the chat_ctx and always before the generated agent speech
 
     def _on_agent_speech_done(self, content: TranscriptionContent) -> None:
-        if not self._is_interrupted:
+        if not self._is_interrupted and content.response_id and content.text:
             self.emit(
                 "agent_speech_transcription_completed",
                 InputTranscription(
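With the new guards, both transcription events fire only when the content actually carries a response id and text, so listeners no longer receive empty transcripts. A hypothetical consumer sketch follows; it assumes the realtime session exposes the EventEmitter-style `on()` registration used elsewhere in livekit-agents, and `attach_transcription_logging` is an illustrative name:

```python
# Hypothetical consumer of the transcription events emitted above;
# assumes EventEmitter-style .on() registration (not shown in this diff).
def attach_transcription_logging(session) -> None:
    @session.on("input_speech_transcription_completed")
    def _on_user_transcript(ev) -> None:
        # After this change, ev.item_id and ev.transcript are non-empty
        # whenever the event fires, so no extra guarding is needed here.
        print(f"user [{ev.item_id}]: {ev.transcript}")

    @session.on("agent_speech_transcription_completed")
    def _on_agent_transcript(ev) -> None:
        print(f"agent [{ev.item_id}]: {ev.transcript}")
```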
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import asyncio
+import re
 from dataclasses import dataclass
 from typing import Literal
 
@@ -135,6 +136,7 @@ async def _recv_task():
                     content.text += part.text
 
                 if server_content.turn_complete:
+                    content.text = clean_transcription(content.text)
                     self.emit("input_speech_done", content)
                     self._active_response_id = None
 
@@ -163,3 +165,9 @@ async def _recv_task():
         finally:
             await utils.aio.gracefully_cancel(*tasks)
             await self._session.close()
+
+
+def clean_transcription(text: str) -> str:
+    text = text.replace("\n", " ")
+    text = re.sub(r"\s+", " ", text)
+    return text.strip()
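The new `clean_transcription` helper flattens the transcription before `input_speech_done` is emitted: newlines become spaces, whitespace runs collapse to a single space, and the result is trimmed. A quick self-contained check of that behavior (the function body is copied verbatim from the diff above):

```python
import re


def clean_transcription(text: str) -> str:  # verbatim from the diff above
    text = text.replace("\n", " ")
    text = re.sub(r"\s+", " ", text)
    return text.strip()


assert clean_transcription("hello\nworld") == "hello world"
assert clean_transcription("  several   spaces\n\nand  newlines ") == "several spaces and newlines"
```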
@@ -33,6 +33,7 @@
 from google import genai
 from google.auth._default_async import default_async
 from google.genai import types
+from google.genai.errors import APIError, ClientError, ServerError
 
 from ._utils import _build_gemini_ctx, _build_tools
 from .log import logger
@@ -220,6 +221,7 @@ def __init__(
 
     async def _run(self) -> None:
         retryable = True
+        request_id = utils.shortuuid()
 
         try:
             opts: dict[str, Any] = dict()
@@ -281,12 +283,11 @@ async def _run(self) -> None:
                 contents=cast(types.ContentListUnion, turns),
                 config=config,
             ):
-                response_id = utils.shortuuid()
                 if response.prompt_feedback:
                     raise APIStatusError(
                         response.prompt_feedback.json(),
                         retryable=False,
-                        request_id=response_id,
+                        request_id=request_id,
                     )
 
                 if (
@@ -297,7 +298,7 @@ async def _run(self) -> None:
                     raise APIStatusError(
                         "No candidates in the response",
                         retryable=True,
-                        request_id=response_id,
+                        request_id=request_id,
                     )
 
                 if len(response.candidates) > 1:
@@ -306,7 +307,7 @@ async def _run(self) -> None:
                     )
 
                 for index, part in enumerate(response.candidates[0].content.parts):
-                    chat_chunk = self._parse_part(response_id, index, part)
+                    chat_chunk = self._parse_part(request_id, index, part)
                     if chat_chunk is not None:
                         retryable = False
                         self._event_ch.send_nowait(chat_chunk)
@@ -315,15 +316,38 @@ async def _run(self) -> None:
                 usage = response.usage_metadata
                 self._event_ch.send_nowait(
                     llm.ChatChunk(
-                        request_id=response_id,
+                        request_id=request_id,
                         usage=llm.CompletionUsage(
                             completion_tokens=usage.candidates_token_count or 0,
                             prompt_tokens=usage.prompt_token_count or 0,
                             total_tokens=usage.total_token_count or 0,
                         ),
                     )
                 )
 
+        except ClientError as e:
+            raise APIStatusError(
+                "gemini llm: client error",
+                status_code=e.code,
+                body=e.message,
+                request_id=request_id,
+                retryable=False if e.code != 429 else True,
+            ) from e
+        except ServerError as e:
+            raise APIStatusError(
+                "gemini llm: server error",
+                status_code=e.code,
+                body=e.message,
+                request_id=request_id,
+                retryable=retryable,
+            ) from e
+        except APIError as e:
+            raise APIStatusError(
+                "gemini llm: api error",
+                status_code=e.code,
+                body=e.message,
+                request_id=request_id,
+                retryable=retryable,
+            ) from e
         except Exception as e:
             raise APIConnectionError(
                 "gemini llm: error generating content",
