#
#
# Agora Real Time Engagement
# Created by XinHui Li in 2024.
# Copyright (c) 2024 Agora IO. All rights reserved.
#
#
from threading import Thread

from rte import (
    Extension,
    RteEnv,
    Cmd,
    Data,
    StatusCode,
    CmdResult,
)

from .gemini_llm import GeminiLLM, GeminiLLMConfig
from .log import logger
from .utils import get_micro_ts, parse_sentence
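
# The two helpers above come from this package's utils module. For this
# sketch, `get_micro_ts` is assumed to return the current timestamp in
# microseconds, and `parse_sentence(sentence, content)` is assumed to append
# streamed `content` to the pending `sentence` and return
# (sentence, remaining_content, sentence_is_final), where `sentence_is_final`
# becomes True once a sentence boundary (e.g. punctuation) has been reached.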

CMD_IN_FLUSH = "flush"
CMD_OUT_FLUSH = "flush"
DATA_IN_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT = "end_of_segment"

PROPERTY_API_KEY = "api_key"  # Required
PROPERTY_GREETING = "greeting"  # Optional
PROPERTY_MAX_MEMORY_LENGTH = "max_memory_length"  # Optional
PROPERTY_MAX_OUTPUT_TOKENS = "max_output_tokens"  # Optional
PROPERTY_MODEL = "model"  # Optional
PROPERTY_PROMPT = "prompt"  # Optional
PROPERTY_TEMPERATURE = "temperature"  # Optional
PROPERTY_TOP_K = "top_k"  # Optional
PROPERTY_TOP_P = "top_p"  # Optional
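
# For illustration only: these properties are read from the rte graph
# configuration. Assuming a JSON property file (the values below are
# hypothetical), it might look roughly like:
#
#   {
#     "api_key": "<your-gemini-api-key>",
#     "model": "gemini-1.0-pro",
#     "prompt": "You are a helpful assistant.",
#     "greeting": "Hello! How can I help you today?",
#     "max_output_tokens": 512,
#     "temperature": 0.1,
#     "top_k": 40,
#     "top_p": 0.95,
#     "max_memory_length": 10
#   }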


class GeminiLLMExtension(Extension):
    # Default state; note these are class attributes, so `memory` would be
    # shared across instances if more than one were ever created.
    memory = []
    max_memory_length = 10
    outdate_ts = 0
    gemini_llm = None

    def on_start(self, rte: RteEnv) -> None:
        logger.info("GeminiLLMExtension on_start")

        # Prepare configuration
        gemini_llm_config = GeminiLLMConfig.default_config()

        # api_key is required; fail fast if it is missing
        try:
            gemini_llm_config.api_key = rte.get_property_string(PROPERTY_API_KEY)
        except Exception as err:
            logger.error(f"get_property_string required {PROPERTY_API_KEY} failed, err: {err}")
            return

        for key in [PROPERTY_GREETING, PROPERTY_MODEL, PROPERTY_PROMPT]:
            try:
                val = rte.get_property_string(key)
                if val:
                    # Use setattr so the attribute named by `key` is assigned
                    # (plain `gemini_llm_config.key = val` would set a literal
                    # attribute named "key").
                    setattr(gemini_llm_config, key, val)
            except Exception as e:
                logger.warning(f"get_property_string optional {key} failed, err: {e}")

        for key in [PROPERTY_TEMPERATURE, PROPERTY_TOP_P]:
            try:
                setattr(gemini_llm_config, key, float(rte.get_property_float(key)))
            except Exception as e:
                logger.warning(f"get_property_float optional {key} failed, err: {e}")

        for key in [PROPERTY_MAX_OUTPUT_TOKENS, PROPERTY_TOP_K]:
            try:
                setattr(gemini_llm_config, key, int(rte.get_property_int(key)))
            except Exception as e:
                logger.warning(f"get_property_int optional {key} failed, err: {e}")

        try:
            prop_max_memory_length = rte.get_property_int(PROPERTY_MAX_MEMORY_LENGTH)
            if prop_max_memory_length > 0:
                self.max_memory_length = prop_max_memory_length
        except Exception as err:
            logger.warning(f"get_property_int optional {PROPERTY_MAX_MEMORY_LENGTH} failed, err: {err}")

        # Create GeminiLLM instance
        self.gemini_llm = GeminiLLM(gemini_llm_config)
        logger.info(f"GeminiLLM created with max_output_tokens: {gemini_llm_config.max_output_tokens}, model: {gemini_llm_config.model}")
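
        # For context: `GeminiLLM.get_chat_completions_stream` (defined in
        # .gemini_llm) is expected to yield chunks exposing a `.text`
        # attribute, which the worker in on_data consumes. A minimal sketch of
        # such a wrapper, assuming the google.generativeai SDK (`config` here
        # stands for GeminiLLMConfig), might look like:
        #
        #   import google.generativeai as genai
        #
        #   genai.configure(api_key=config.api_key)
        #   model = genai.GenerativeModel(config.model)
        #
        #   def get_chat_completions_stream(messages):
        #       return model.generate_content(messages, stream=True)
        #
        # The real implementation may differ; see gemini_llm.py.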

        # Send greeting if available (optional, so a missing property is not fatal)
        try:
            greeting = rte.get_property_string(PROPERTY_GREETING)
        except Exception as e:
            greeting = ""
            logger.warning(f"get_property_string optional {PROPERTY_GREETING} failed, err: {e}")

        if greeting:
            try:
                output_data = Data.create("text_data")
                output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, greeting)
                output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
                rte.send_data(output_data)
                logger.info(f"greeting [{greeting}] sent")
            except Exception as e:
                logger.error(f"greeting [{greeting}] send failed, err: {e}")

        rte.on_start_done()

    def on_stop(self, rte: RteEnv) -> None:
        logger.info("GeminiLLMExtension on_stop")
        rte.on_stop_done()

    def on_cmd(self, rte: RteEnv, cmd: Cmd) -> None:
        logger.info("GeminiLLMExtension on_cmd")
        cmd_json = cmd.to_json()
        logger.info(f"GeminiLLMExtension on_cmd json: {cmd_json}")

        cmd_name = cmd.get_name()

        if cmd_name == CMD_IN_FLUSH:
            # Mark any in-flight completion as outdated, then propagate the
            # flush downstream.
            self.outdate_ts = get_micro_ts()
            cmd_out = Cmd.create(CMD_OUT_FLUSH)
            rte.send_cmd(cmd_out, None)
            logger.info("GeminiLLMExtension on_cmd sent flush")
        else:
            logger.info(f"GeminiLLMExtension on_cmd unknown cmd: {cmd_name}")
            cmd_result = CmdResult.create(StatusCode.ERROR)
            cmd_result.set_property_string("detail", "unknown cmd")
            rte.return_result(cmd_result, cmd)
            return

        cmd_result = CmdResult.create(StatusCode.OK)
        cmd_result.set_property_string("detail", "success")
        rte.return_result(cmd_result, cmd)

    def on_data(self, rte: RteEnv, data: Data) -> None:
        """
        on_data receives data from the rte graph.
        Currently supported data:
          - name: text_data
            example:
            {"name": "text_data", "properties": {"text": "hello", "is_final": true}}
        """
        logger.info("GeminiLLMExtension on_data")

        # Only act on final recognition results
        try:
            is_final = data.get_property_bool(DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL)
            if not is_final:
                logger.info("ignore non-final input")
                return
        except Exception as e:
            logger.error(f"on_data get_property_bool {DATA_IN_TEXT_DATA_PROPERTY_IS_FINAL} failed, err: {e}")
            return

        # Get input text
        try:
            input_text = data.get_property_string(DATA_IN_TEXT_DATA_PROPERTY_TEXT)
            if not input_text:
                logger.info("ignore empty text")
                return
            logger.info(f"on_data input text: [{input_text}]")
        except Exception as e:
            logger.error(f"on_data get_property_string {DATA_IN_TEXT_DATA_PROPERTY_TEXT} failed, err: {e}")
            return

        # Prepare memory: trim to the configured window before appending the
        # new user turn.
        while len(self.memory) > self.max_memory_length:
            self.memory.pop(0)
        self.memory.append({"role": "user", "parts": input_text})
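
        # For reference, after a few turns `self.memory` is expected to hold
        # alternating entries in the Gemini content format, e.g.:
        #   [
        #       {"role": "user", "parts": "Hi"},
        #       {"role": "model", "parts": "Hello! How can I help?"},
        #       {"role": "user", "parts": "What's the weather?"},
        #   ]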

        def chat_completions_stream_worker(start_time, input_text, memory):
            try:
                logger.info(f"chat_completions_stream_worker for input text: [{input_text}] memory: {memory}")

                # Get result from AI
                resp = self.gemini_llm.get_chat_completions_stream(memory)
                if resp is None:
                    logger.info(f"chat_completions_stream_worker for input text: [{input_text}] failed")
                    return

                sentence = ""
                full_content = ""
                first_sentence_sent = False

                for chat_completions in resp:
                    # Drop the rest of this stream if a flush arrived after it started
                    if start_time < self.outdate_ts:
                        logger.info(f"chat_completions_stream_worker recv interrupt and flushing for input text: [{input_text}], startTs: {start_time}, outdateTs: {self.outdate_ts}")
                        break

                    content = chat_completions.text if chat_completions.text is not None else ""
                    full_content += content

                    while True:
                        sentence, content, sentence_is_final = parse_sentence(sentence, content)

                        if len(sentence) == 0 or not sentence_is_final:
                            logger.info(f"sentence [{sentence}] is empty or not final")
                            break

                        logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] got sentence: [{sentence}]")

                        # Send sentence
                        try:
                            output_data = Data.create("text_data")
                            output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
                            output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, False)
                            rte.send_data(output_data)
                            logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] sent sentence [{sentence}]")
                        except Exception as e:
                            logger.error(f"chat_completions_stream_worker recv for input text: [{input_text}] send sentence [{sentence}] failed, err: {e}")
                            break

                        sentence = ""
                        if not first_sentence_sent:
                            first_sentence_sent = True
                            # Timestamps are in microseconds; convert to ms for the log
                            logger.info(f"chat_completions_stream_worker recv for input text: [{input_text}] first sentence sent, first_sentence_latency {(get_micro_ts() - start_time) / 1000}ms")

                # Remember response as model content in memory
                memory.append({"role": "model", "parts": full_content})

                # Send end of segment
                try:
                    output_data = Data.create("text_data")
                    output_data.set_property_string(DATA_OUT_TEXT_DATA_PROPERTY_TEXT, sentence)
                    output_data.set_property_bool(DATA_OUT_TEXT_DATA_PROPERTY_TEXT_END_OF_SEGMENT, True)
                    rte.send_data(output_data)
                    logger.info(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] sent")
                except Exception as e:
                    logger.error(f"chat_completions_stream_worker for input text: [{input_text}] end of segment with sentence [{sentence}] send failed, err: {e}")

            except Exception as e:
                logger.error(f"chat_completions_stream_worker for input text: [{input_text}] failed, err: {e}")

        # Start a worker thread so the streaming request does not block on_data
        start_time = get_micro_ts()
        thread = Thread(
            target=chat_completions_stream_worker,
            args=(start_time, input_text, self.memory),
        )
        thread.start()
        logger.info("GeminiLLMExtension on_data end")
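
        # Note: the worker thread is deliberately not joined here.
        # Interruption is handled by the flush command updating
        # self.outdate_ts, which the worker checks on every streamed chunk.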