 from typing import Optional, Iterable, List, Dict

 import requests
-from ovos_plugin_manager.templates.solvers import QuestionSolver
-from ovos_utils.log import LOG

 from ovos_plugin_manager.templates.solvers import ChatMessageSolver
+from ovos_plugin_manager.templates.solvers import QuestionSolver
+from ovos_plugin_manager.templates.language import LanguageTranslator, LanguageDetector
+from ovos_utils.log import LOG

 MessageList = List[Dict[str, str]]  # for typing

-class OpenAICompletionsSolver(QuestionSolver):
-    enable_tx = False
-    priority = 25

-    def __init__(self, config=None):
-        super().__init__(config)
+class OpenAICompletionsSolver(QuestionSolver):
+    def __init__(self, config=None,
+                 translator: Optional[LanguageTranslator] = None,
+                 detector: Optional[LanguageDetector] = None,
+                 priority: int = 50,
+                 enable_tx: bool = False,
+                 enable_cache: bool = False,
+                 internal_lang: Optional[str] = None):
+        super().__init__(config=config, translator=translator,
+                         detector=detector, priority=priority,
+                         enable_tx=enable_tx, enable_cache=enable_cache,
+                         internal_lang=internal_lang)
         self.api_url = f"{self.config.get('api_url', 'https://api.openai.com/v1')}/completions"
         self.engine = self.config.get("model", "text-davinci-002")  # "ada" cheaper and faster, "davinci" better
         self.stop_token = "<|im_end|>"
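With `enable_tx` and `priority` no longer class attributes, translation, caching and priority are plain keyword arguments forwarded to the base `QuestionSolver`. A minimal construction sketch follows; the import path and the `"key"` config entry are assumptions not shown in this hunk, and the values are placeholders:

```python
# Sketch only: the import path is an assumption about the plugin package layout.
from ovos_solver_openai_persona import OpenAICompletionsSolver

solver = OpenAICompletionsSolver(
    config={
        "api_url": "https://api.openai.com/v1",  # default used in __init__
        "model": "text-davinci-002",             # default used in __init__
        "key": "sk-...",                         # assumed config entry, placeholder value
    },
    priority=50,        # previously a class attribute (25), now a keyword argument defaulting to 50
    enable_tx=False,    # translation plugins stay disabled unless explicitly enabled
    enable_cache=False,
)
```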
@@ -78,11 +86,17 @@ def post_process_sentence(text: str) -> str:


 class OpenAIChatCompletionsSolver(ChatMessageSolver):
-    enable_tx = False
-    priority = 25
-
-    def __init__(self, config=None):
-        super().__init__(config)
+    def __init__(self, config=None,
+                 translator: Optional[LanguageTranslator] = None,
+                 detector: Optional[LanguageDetector] = None,
+                 priority: int = 25,
+                 enable_tx: bool = False,
+                 enable_cache: bool = False,
+                 internal_lang: Optional[str] = None):
+        super().__init__(config=config, translator=translator,
+                         detector=detector, priority=priority,
+                         enable_tx=enable_tx, enable_cache=enable_cache,
+                         internal_lang=internal_lang)
         self.api_url = f"{self.config.get('api_url', 'https://api.openai.com/v1')}/chat/completions"
         self.engine = self.config.get("model", "gpt-4o-mini")  # "ada" cheaper and faster, "davinci" better
         self.stop_token = "<|im_end|>"
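The chat solver gets the same constructor treatment but keeps its old priority of 25 as the default. Since `api_url` comes from config, the solver can target any OpenAI-compatible `/chat/completions` endpoint; a hedged sketch with placeholder values:

```python
# Sketch only: the import path is an assumption, and all values are placeholders.
from ovos_solver_openai_persona import OpenAIChatCompletionsSolver

config = {
    "api_url": "http://localhost:8080/v1",  # any OpenAI-compatible server
    "model": "gpt-4o-mini",                 # default used in __init__
    "key": "sk-...",                        # required; __init__ raises ValueError without it
}
solver = OpenAIChatCompletionsSolver(config=config)
```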
@@ -91,7 +105,7 @@ def __init__(self, config=None):
             LOG.error("key not set in config")
             raise ValueError("key must be set")
         self.memory = config.get("enable_memory", True)
-        self.max_utts = config.get("memory_size", 15)
+        self.max_utts = config.get("memory_size", 5)
         self.qa_pairs = []  # tuple of q+a
         self.initial_prompt = config.get("initial_prompt", "You are a helpful assistant.")

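`memory_size` now defaults to 5 q/a pairs instead of 15, bounding how much prior dialog gets replayed into each request. The trimming logic is not part of this hunk; the sketch below only illustrates how a cap like `max_utts` is typically applied when the message list is rebuilt, and is not the plugin's actual `get_messages` code:

```python
from typing import Dict, List, Tuple

MessageList = List[Dict[str, str]]

def build_messages(initial_prompt: str,
                   qa_pairs: List[Tuple[str, str]],
                   query: str,
                   max_utts: int = 5) -> MessageList:
    """Illustrative only: keep just the last `max_utts` q/a pairs as chat history."""
    messages: MessageList = [{"role": "system", "content": initial_prompt}]
    for q, a in qa_pairs[-max_utts:]:
        messages.append({"role": "user", "content": q})
        messages.append({"role": "assistant", "content": a})
    messages.append({"role": "user", "content": query})
    return messages
```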
@@ -154,6 +168,9 @@ def _do_streaming_api_request(self, messages):
             if chunk:
                 chunk = chunk.decode("utf-8")
                 chunk = json.loads(chunk.split("data: ", 1)[-1])
+                if "error" in chunk and "message" in chunk["error"]:
+                    LOG.error("API returned an error: " + chunk["error"]["message"])
+                    break
                 if chunk["choices"][0].get("finish_reason"):
                     break
                 if "content" not in chunk["choices"][0]["delta"]:
@@ -264,4 +281,3 @@ def get_spoken_answer(self, query: str,
         messages = self.get_messages(query)
         # just for api compat since it's a subclass, shouldn't be directly used
         return self.continue_chat(messages=messages, lang=lang, units=units)
-
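As the comment notes, `get_spoken_answer` only wraps `continue_chat` for API compatibility; chat-style callers are expected to build a `MessageList` and call `continue_chat` directly, as sketched here with placeholder `lang`/`units` values:

```python
# Usage sketch: `solver` is an OpenAIChatCompletionsSolver instance as constructed above.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the tallest mountain on Earth?"},
]
answer = solver.continue_chat(messages=messages, lang="en-US", units="metric")
print(answer)
```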