Skip to content

Commit 7a92c95

Browse files
feat(): add gemini llm extension (#199)
1 parent 3e348b6 commit 7a92c95

12 files changed

+658
-1
lines changed

.env.example

+4
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ COSY_TTS_KEY=
5151
# ElevenLabs TTS key
5252
ELEVENLABS_TTS_KEY=
5353

54+
# Extension: gemini_llm
55+
# Gemini API key
56+
GEMINI_API_KEY=
57+
5458
# Extension: litellm
5559
# Using Environment Variables, refer to https://docs.litellm.ai/docs/providers
5660
# For example:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from . import gemini_llm_addon
2+
from .extension import EXTENSION_NAME
3+
from .log import logger
4+
5+
6+
logger.info(f"{EXTENSION_NAME} extension loaded")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
# Canonical name of this extension; used for addon registration and as the
# logger name.
EXTENSION_NAME = "gemini_llm_python"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from typing import Dict, List
2+
import google.generativeai as genai
3+
4+
5+
class GeminiLLMConfig:
    """Configuration container for a Gemini chat session.

    Holds the API key, model selection, system prompt, and the sampling
    parameters forwarded to the generation call.
    """

    def __init__(self,
                 api_key: str,
                 max_output_tokens: int,
                 model: str,
                 prompt: str,
                 temperature: float,
                 top_k: int,
                 top_p: float):
        # Every constructor argument is stored verbatim on the instance.
        self.api_key = api_key
        self.model = model
        self.prompt = prompt
        self.max_output_tokens = max_output_tokens
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p

    @classmethod
    def default_config(cls):
        """Build a config with sensible defaults and an empty API key."""
        defaults = dict(
            api_key="",
            max_output_tokens=512,
            model="gemini-1.0-pro-latest",
            prompt="You are a voice assistant who talks in a conversational way and can chat with me like my friends. I will speak to you in English or Chinese, and you will answer in the corrected and improved version of my text with the language I use. Don’t talk like a robot, instead I would like you to talk like a real human with emotions. I will use your answer for text-to-speech, so don’t return me any meaningless characters. I want you to be helpful, when I’m asking you for advice, give me precise, practical and useful advice instead of being vague. When giving me a list of options, express the options in a narrative way instead of bullet points.",
            temperature=0.1,
            top_k=40,
            top_p=0.95,
        )
        return cls(**defaults)
33+
34+
35+
class GeminiLLM:
    """Thin wrapper around the google.generativeai chat API.

    Configures the SDK with the supplied API key at construction time
    (note: genai.configure is process-wide) and exposes a streaming
    chat-completion call.
    """

    def __init__(self, config: GeminiLLMConfig):
        self.config = config
        genai.configure(api_key=self.config.api_key)
        self.model = genai.GenerativeModel(self.config.model)

    def get_chat_completions_stream(self, messages: List[Dict[str, str]]):
        """Start a streaming chat completion.

        All entries except the last become the chat history; the final
        entry's "parts" value is sent together with the configured system
        prompt. Returns the SDK's streaming response iterator.

        Raises:
            Exception: wrapping (and chaining) any SDK error.
        """
        try:
            chat = self.model.start_chat(history=messages[0:-1])
            response = chat.send_message((self.config.prompt, messages[-1].get("parts")),
                                         generation_config=genai.types.GenerationConfig(
                                             max_output_tokens=self.config.max_output_tokens,
                                             temperature=self.config.temperature,
                                             top_k=self.config.top_k,
                                             top_p=self.config.top_p),
                                         stream=True)

            return response
        except Exception as e:
            # Chain the original exception so the root cause and traceback
            # are not lost (the original re-raise dropped them).
            raise Exception(f"get_chat_completions_stream failed, err: {e}") from e
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#
2+
#
3+
# Agora Real Time Engagement
4+
# Created by XinHui Li in 2024.
5+
# Copyright (c) 2024 Agora IO. All rights reserved.
6+
#
7+
#
8+
from rte import (
9+
Addon,
10+
register_addon_as_extension,
11+
RteEnv,
12+
)
13+
from .extension import EXTENSION_NAME
14+
from .log import logger
15+
from .gemini_llm_extension import GeminiLLMExtension
16+
17+
18+
@register_addon_as_extension(EXTENSION_NAME)
class GeminiLLMExtensionAddon(Addon):
    """Addon glue that the RTE framework calls to instantiate the extension."""

    def on_create_instance(self, rte: RteEnv, addon_name: str, context) -> None:
        """Create a GeminiLLMExtension and hand it back to the framework."""
        logger.info("on_create_instance")
        instance = GeminiLLMExtension(addon_name)
        rte.on_create_instance_done(instance, context)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
#
2+
#
3+
# Agora Real Time Engagement
4+
# Created by XinHui Li in 2024.
5+
# Copyright (c) 2024 Agora IO. All rights reserved.
6+
#
7+
#
8+
from threading import Thread
9+
from rte import (
10+
Extension,
11+
RteEnv,
12+
Cmd,
13+
Data,
14+
StatusCode,
15+
CmdResult,
16+
)
17+
from .gemini_llm import GeminiLLM, GeminiLLMConfig
18+
from .log import logger
19+
from .utils import get_micro_ts, parse_sentence
20+
21+
22+
# Command names exchanged on the RTE graph (incoming and forwarded flush).
CMD_IN_FLUSH = "flush"
CMD_OUT_FLUSH = "flush"
# Property names on incoming/outgoing "text_data" packets.
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT = "end_of_segment"

# Extension configuration property names, read in on_start.
PROPERTY_API_KEY = "api_key"  # Required
PROPERTY_GREETING = "greeting"  # Optional
PROPERTY_MAX_MEMORY_LENGTH = "max_memory_length"  # Optional
PROPERTY_MAX_OUTPUT_TOKENS = "max_output_tokens"  # Optional
PROPERTY_MODEL = "model"  # Optional
PROPERTY_PROMPT = "prompt"  # Optional
PROPERTY_TEMPERATURE = "temperature"  # Optional
PROPERTY_TOP_K = "top_k"  # Optional
PROPERTY_TOP_P = "top_p"  # Optional
38+
39+
40+
class GeminiLLMExtension(Extension):
    """RTE extension that streams chat completions from Gemini.

    Receives final ASR text on "text_data" packets, keeps a rolling
    conversation memory, and streams sentence-sized responses back out.
    """

    # Class-level defaults; on_start rebinds the mutable ones per instance.
    memory = []
    max_memory_length = 10
    outdate_ts = 0
    gemini_llm = None

    def on_start(self, rte: RteEnv) -> None:
        """Read config properties, build the GeminiLLM client, send greeting."""
        logger.info("GeminiLLMExtension on_start")
        # Rebind mutable state per instance: the class attribute `memory = []`
        # would otherwise be shared by every extension instance.
        self.memory = []
        self.outdate_ts = 0

        # Prepare configuration
        gemini_llm_config = GeminiLLMConfig.default_config()

        try:
            api_key = rte.get_property_string(PROPERTY_API_KEY)
            gemini_llm_config.api_key = api_key
        except Exception as err:
            logger.info(f"GetProperty required {PROPERTY_API_KEY} failed, err: {err}")
            return

        for key in [PROPERTY_GREETING, PROPERTY_MODEL, PROPERTY_PROMPT]:
            try:
                val = rte.get_property_string(key)
                if val:
                    # BUGFIX: was `gemini_llm_config.key = val`, which wrote a
                    # literal attribute named "key" and silently dropped the
                    # optional property. (The greeting is also re-read below;
                    # storing it on the config is harmless.)
                    setattr(gemini_llm_config, key, val)
            except Exception as e:
                logger.warning(f"get_property_string optional {key} failed, err: {e}")

        for key in [PROPERTY_TEMPERATURE, PROPERTY_TOP_P]:
            try:
                # BUGFIX: same `.key` typo as above.
                setattr(gemini_llm_config, key, float(rte.get_property_float(key)))
            except Exception as e:
                logger.warning(f"get_property_float optional {key} failed, err: {e}")

        for key in [PROPERTY_MAX_OUTPUT_TOKENS, PROPERTY_TOP_K]:
            try:
                # BUGFIX: same `.key` typo as above.
                setattr(gemini_llm_config, key, int(rte.get_property_int(key)))
            except Exception as e:
                logger.warning(f"get_property_int optional {key} failed, err: {e}")

        try:
            prop_max_memory_length = rte.get_property_int(PROPERTY_MAX_MEMORY_LENGTH)
            if prop_max_memory_length > 0:
                self.max_memory_length = int(prop_max_memory_length)
        except Exception as err:
            logger.warning(f"GetProperty optional {PROPERTY_MAX_MEMORY_LENGTH} failed, err: {err}")

        # Create GeminiLLM instance
        self.gemini_llm = GeminiLLM(gemini_llm_config)
        logger.info(f"newGeminiLLM succeed with max_output_tokens: {gemini_llm_config.max_output_tokens}, model: {gemini_llm_config.model}")

        # Send greeting if available. Guarded: a missing optional property
        # must not abort on_start before on_start_done() is called.
        greeting = None
        try:
            greeting = rte.get_property_string(PROPERTY_GREETING)
        except Exception as e:
            logger.warning(f"get_property_string optional {PROPERTY_GREETING} failed, err: {e}")
        if greeting:
            try:
                output_data = Data.create("text_data")
                output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting)
                output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
                rte.send_data(output_data)
                logger.info(f"greeting [{greeting}] sent")
            except Exception as e:
                logger.error(f"greeting [{greeting}] send failed, err: {e}")

        rte.on_start_done()

    def on_stop(self, rte: RteEnv) -> None:
        logger.info("GeminiLLMExtension on_stop")
        rte.on_stop_done()

    def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None:
        """Handle graph commands; only "flush" is supported.

        Flush records an interrupt timestamp (so in-flight workers stop
        emitting) and forwards the flush downstream.
        """
        logger.info("GeminiLLMExtension on_cmd")
        cmd_json = cmd.to_json()
        logger.info(f"GeminiLLMExtension on_cmd json: {cmd_json}")

        cmd_name = cmd.get_name()

        if cmd_name == CMD_IN_FLUSH:
            self.outdate_ts = get_micro_ts()
            cmd_out = Cmd.create(CMD_OUT_FLUSH)
            rte.send_cmd(cmd_out, None)
            logger.info(f"GeminiLLMExtension on_cmd sent flush")
        else:
            logger.info(f"GeminiLLMExtension on_cmd unknown cmd: {cmd_name}")
            cmd_result = CmdResult.create(StatusCode.ERROR)
            cmd_result.set_property_string("detail", "unknown cmd")
            rte.return_result(cmd_result, cmd)
            return

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "success")
        rte.return_result(cmd_result, cmd)

    def on_data(self, rte: RteEnv, data: Data) -> None:
        """
        on_data receives data from rte graph.
        current supported data:
          - name: text_data
            example:
            {name: text_data, properties: {text: "hello"}
        """
        logger.info(f"GeminiLLMExtension on_data")

        # Only final ASR results are forwarded to the model.
        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
            if not is_final:
                logger.info("ignore non-final input")
                return
        except Exception as e:
            logger.error(f"on_data get_property_bool {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {e}")
            return

        # Get input text
        try:
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
            if not input_text:
                logger.info("ignore empty text")
                return
            logger.info(f"on_data input text: [{input_text}]")
        except Exception as e:
            logger.error(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {e}")
            return

        # Prepare memory: cap the rolling history, then append the new turn.
        if len(self.memory) > self.max_memory_length:
            self.memory.pop(0)
        self.memory.append({"role": "user", "parts": input_text})

        def chat_completions_stream_worker(start_time, input_text, memory):
            # Runs on a background thread: stream the model response and emit
            # it sentence by sentence; abort if a flush arrived after start.
            try:
                logger.info(f"chat_completions_stream_worker for input text: [{input_text}] memory: {memory}")

                # Get result from AI
                resp = self.gemini_llm.get_chat_completions_stream(memory)
                if resp is None:
                    logger.info(f"chat_completions_stream_worker for input text: [{input_text}] failed")
                    return

                sentence = ""
                full_content = ""
                first_sentence_sent = False

                for chat_completions in resp:
                    # A flush newer than this request invalidates it.
                    if start_time < self.outdate_ts:
                        logger.info(f"chat_completions_stream_worker recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}")
                        break

                    if (chat_completions.text is not None):
                        content = chat_completions.text
                    else:
                        content = ""

                    full_content += content

                    # Drain all complete sentences from the accumulated text.
                    while True:
                        sentence, content, sentence_is_final = parse_sentence(sentence, content)

                        if len(sentence) == 0 or not sentence_is_final:
                            logger.info(f"sentence {sentence} is empty or not final")
                            break

                        logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] got sentence: [{sentence}]")

                        # send sentence
                        try:
                            output_data = Data.create("text_data")
                            output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
                            output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, False)
                            rte.send_data(output_data)
                            logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] sent sentence [{sentence}]")
                        except Exception as e:
                            logger.error(f"chat_completions_stream_worker recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {e}")
                            break

                        sentence = ""
                        if not first_sentence_sent:
                            first_sentence_sent = True
                            logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] first sentence sent, first_sentence_latency {get_micro_ts() - start_time}ms")

                # remember response as assistant content in memory
                memory.append({"role": "model", "parts": full_content})

                # send end of segment
                try:
                    output_data = Data.create("text_data")
                    output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
                    output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
                    rte.send_data(output_data)
                    logger.info(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] sent")
                except Exception as e:
                    logger.error(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {e}")

            except Exception as e:
                logger.error(f"chat_completions_stream_worker for input text: [{input_text}] failed, err: {e}")

        # Start thread to request and read responses from GeminiLLM
        start_time = get_micro_ts()
        thread = Thread(
            target=chat_completions_stream_worker,
            args=(start_time, input_text, self.memory),
        )
        thread.start()
        logger.info(f"GeminiLLMExtension on_data end")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import logging
from .extension import EXTENSION_NAME

# One stderr handler with a verbose format: timestamp, logger name, level,
# pid, source location, message.
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(process)d - [%(filename)s:%(lineno)d] - %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

# Module-wide logger for the extension, emitting INFO and above.
logger = logging.getLogger(EXTENSION_NAME)
logger.setLevel(logging.INFO)
logger.addHandler(console_handler)

0 commit comments

Comments
 (0)