Skip to content

Commit 45affd1

Browse files
committed
add litellm
1 parent 6c873a3 commit 45affd1

File tree

1 file changed

+5
-16
lines changed
  • agents/ten_packages/extension/openai_chatgpt_python

1 file changed

+5
-16
lines changed

agents/ten_packages/extension/openai_chatgpt_python/openai.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
 from dataclasses import dataclass
 import random
 import requests
-from openai import AsyncOpenAI, AsyncAzureOpenAI
+import litellm
 from openai.types.chat.chat_completion import ChatCompletion
 
 from ten.async_ten_env import AsyncTenEnv
@@ -46,17 +46,6 @@ class OpenAIChatGPT:
     def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
         self.config = config
         ten_env.log_info(f"OpenAIChatGPT initialized with config: {config.api_key}")
-        if self.config.vendor == "azure":
-            self.client = AsyncAzureOpenAI(
-                api_key=config.api_key,
-                api_version=self.config.azure_api_version,
-                azure_endpoint=config.azure_endpoint,
-            )
-            ten_env.log_info(
-                f"Using Azure OpenAI with endpoint: {config.azure_endpoint}, api_version: {config.azure_api_version}"
-            )
-        else:
-            self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url)
         self.session = requests.Session()
         if config.proxy_url:
             proxies = {
@@ -69,7 +58,7 @@ def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
 
     async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
         req = {
-            "model": self.config.model,
+            "model": f"{self.config.vendor}/{self.config.model}",
             "messages": [
                 {
                     "role": "system",
@@ -87,15 +76,15 @@ async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
         }
 
         try:
-            response = await self.client.chat.completions.create(**req)
+            response = await litellm.acompletion(**req)
         except Exception as e:
             raise RuntimeError(f"CreateChatCompletion failed, err: {e}") from e
 
         return response
 
     async def get_chat_completions_stream(self, messages, tools=None, listener=None):
         req = {
-            "model": self.config.model,
+            "model": f"{self.config.vendor}/{self.config.model}",
             "messages": [
                 {
                     "role": "system",
@@ -114,7 +103,7 @@ async def get_chat_completions_stream(self, messages, tools=None, listener=None)
         }
 
         try:
-            response = await self.client.chat.completions.create(**req)
+            response = await litellm.acompletion(**req)
         except Exception as e:
             raise RuntimeError(f"CreateChatCompletionStream failed, err: {e}") from e
 
0 commit comments

Comments
 (0)