9
9
from dataclasses import dataclass
10
10
import random
11
11
import requests
12
- from openai import AsyncOpenAI , AsyncAzureOpenAI
12
+ import litellm
13
13
from openai .types .chat .chat_completion import ChatCompletion
14
14
15
15
from ten .async_ten_env import AsyncTenEnv
@@ -46,17 +46,6 @@ class OpenAIChatGPT:
46
46
def __init__ (self , ten_env : AsyncTenEnv , config : OpenAIChatGPTConfig ):
47
47
self .config = config
48
48
ten_env .log_info (f"OpenAIChatGPT initialized with config: { config .api_key } " )
49
- if self .config .vendor == "azure" :
50
- self .client = AsyncAzureOpenAI (
51
- api_key = config .api_key ,
52
- api_version = self .config .azure_api_version ,
53
- azure_endpoint = config .azure_endpoint ,
54
- )
55
- ten_env .log_info (
56
- f"Using Azure OpenAI with endpoint: { config .azure_endpoint } , api_version: { config .azure_api_version } "
57
- )
58
- else :
59
- self .client = AsyncOpenAI (api_key = config .api_key , base_url = config .base_url )
60
49
self .session = requests .Session ()
61
50
if config .proxy_url :
62
51
proxies = {
@@ -69,7 +58,7 @@ def __init__(self, ten_env: AsyncTenEnv, config: OpenAIChatGPTConfig):
69
58
70
59
async def get_chat_completions (self , messages , tools = None ) -> ChatCompletion :
71
60
req = {
72
- "model" : self .config .model ,
61
+ "model" : f" { self .config .vendor } / { self . config . model } " ,
73
62
"messages" : [
74
63
{
75
64
"role" : "system" ,
@@ -87,15 +76,15 @@ async def get_chat_completions(self, messages, tools=None) -> ChatCompletion:
87
76
}
88
77
89
78
try :
90
- response = await self . client . chat . completions . create (** req )
79
+ response = await litellm . acompletion (** req )
91
80
except Exception as e :
92
81
raise RuntimeError (f"CreateChatCompletion failed, err: { e } " ) from e
93
82
94
83
return response
95
84
96
85
async def get_chat_completions_stream (self , messages , tools = None , listener = None ):
97
86
req = {
98
- "model" : self .config .model ,
87
+ "model" : f" { self .config .vendor } / { self . config . model } " ,
99
88
"messages" : [
100
89
{
101
90
"role" : "system" ,
@@ -114,7 +103,7 @@ async def get_chat_completions_stream(self, messages, tools=None, listener=None)
114
103
}
115
104
116
105
try :
117
- response = await self . client . chat . completions . create (** req )
106
+ response = await litellm . acompletion (** req )
118
107
except Exception as e :
119
108
raise RuntimeError (f"CreateChatCompletionStream failed, err: { e } " ) from e
120
109
0 commit comments