From 45affd17a2ebddf767c7ab6fc0c58db9d9237425 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Wed, 1 Jan 2025 08:29:02 -0800
Subject: [PATCH] add litellm

---
 .../extension/openai_chatgpt_python/openai.py | 21 +++++--------------
 1 file changed, 5 insertions(+), 16 deletions(-)

diff --git a/agents/ten_packages/extension/openai_chatgpt_python/openai.py b/agents/ten_packages/extension/openai_chatgpt_python/openai.py
index 8c4845ea..771cdd85 100644
--- a/agents/ten_packages/extension/openai_chatgpt_python/openai.py
+++ b/agents/ten_packages/extension/openai_chatgpt_python/openai.py
@@ -9,7 +9,7 @@
 from dataclasses import dataclass
 import random
 import requests
-from openai import AsyncOpenAI, AsyncAzureOpenAI
+import litellm
 from openai.types.chat.chat_completion import ChatCompletion
 
 from ten.async_ten_env import AsyncTenEnv
@@ -46,17 +46,6 @@ class OpenAIChatGPT:
     def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
         self.config = config
         ten_env.log_info(f"OpenAIChatGPT initialized with config: {config.api_key}")
-        if self.config.vendor == "azure":
-            self.client = AsyncAzureOpenAI(
-                api_key=config.api_key,
-                api_version=self.config.azure_api_version,
-                azure_endpoint=config.azure_endpoint,
-            )
-            ten_env.log_info(
-                f"Using Azure OpenAI with endpoint: {config.azure_endpoint}, api_version: {config.azure_api_version}"
-            )
-        else:
-            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         self.session = requests.Session()
         if config.proxy_url:
             proxies = {
@@ -69,7 +58,7 @@ def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
 
     async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
         req = {
-            "model": self.config.model,
+            "model": f"{self.config.vendor}/{self.config.model}",
             "messages": [
                 {
                     "role": "system",
@@ -87,7 +76,7 @@ async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
         }
 
         try:
-            response = await self.client.chat.completions.create(**req)
+            response = await litellm.acompletion(**req)
         except Exception as e:
             raise RuntimeError(f"CreateChatCompletion failed, err: {e}") from e
 
@@ -95,7 +84,7 @@ async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
     async def get_chat_completions_stream(self, messages, tools=None, listener=None):
         req = {
-            "model": self.config.model,
+            "model": f"{self.config.vendor}/{self.config.model}",
             "messages": [
                 {
                     "role": "system",
@@ -114,7 +103,7 @@ async def get_chat_completions_stream(self, messages, tools=None, listener=None)
         }
 
         try:
-            response = await self.client.chat.completions.create(**req)
+            response = await litellm.acompletion(**req)
         except Exception as e:
             raise RuntimeError(f"CreateChatCompletionStream failed, err: {e}") from e
 