@@ -24,7 +24,7 @@ def __init__(self, config=None,
24
24
enable_tx = enable_tx , enable_cache = enable_cache ,
25
25
internal_lang = internal_lang )
26
26
self .api_url = f"{ self .config .get ('api_url' , 'https://api.openai.com/v1' )} /completions"
27
- self .engine = self .config .get ("model" , "text-davinci-002" ) # "ada" cheaper and faster, "davinci" better
27
+ self .engine = self .config .get ("model" , "gpt-4o-mini" )
28
28
self .key = self .config .get ("key" )
29
29
if not self .key :
30
30
LOG .error ("key not set in config" )
@@ -107,7 +107,14 @@ def __init__(self, config=None,
107
107
self .memory = config .get ("enable_memory" , True )
108
108
self .max_utts = config .get ("memory_size" , 3 )
109
109
self .qa_pairs = [] # tuple of q+a
110
- self .initial_prompt = config .get ("initial_prompt" , "You are a helpful assistant." )
110
+ if "persona" in config :
111
+ LOG .warning ("'persona' config option is deprecated, use 'system_prompt' instead" )
112
+ if "initial_prompt" in config :
113
+ LOG .warning ("'initial_prompt' config option is deprecated, use 'system_prompt' instead" )
114
+ self .system_prompt = config .get ("system_prompt" ) or config .get ("initial_prompt" )
115
+ if not self .system_prompt :
116
+ self .system_prompt = "You are a helpful assistant."
117
+ LOG .error (f"system prompt not set in config! defaulting to '{ self .system_prompt } '" )
111
118
112
119
# OpenAI API integration
113
120
def _do_api_request (self , messages ):
@@ -179,19 +186,19 @@ def _do_streaming_api_request(self, messages):
179
186
continue
180
187
yield chunk ["choices" ][0 ]["delta" ]["content" ]
181
188
182
def get_chat_history(self, system_prompt=None):
    """Build the message list for a chat-completion request.

    Assembles the system prompt followed by up to ``self.max_utts`` of the
    most recent stored question/answer pairs, in OpenAI chat format.

    Args:
        system_prompt (Optional[str]): override for the configured system
            prompt; falls back to ``self.system_prompt`` and finally to a
            generic default.

    Returns:
        list: messages of ``{"role": ..., "content": ...}`` dicts, starting
        with the system message, then alternating user/assistant turns.
    """
    # Guard the memory window explicitly: the naive slice
    # ``self.qa_pairs[-self.max_utts:]`` breaks for max_utts == 0, because
    # ``[-0:]`` is ``[0:]`` and returns the FULL history instead of none.
    if self.max_utts > 0:
        qa = self.qa_pairs[-self.max_utts:]
    else:
        qa = []
    system_prompt = system_prompt or self.system_prompt or "You are a helpful assistant."
    messages = [
        {"role": "system", "content": system_prompt},
    ]
    for q, a in qa:
        messages.append({"role": "user", "content": q})
        messages.append({"role": "assistant", "content": a})
    return messages
192
199
193
def get_messages(self, utt, system_prompt=None) -> MessageList:
    """Return the chat history with *utt* appended as the newest user turn.

    Args:
        utt (str): the incoming user utterance.
        system_prompt (Optional[str]): optional system-prompt override,
            forwarded to ``get_chat_history``.

    Returns:
        MessageList: prior conversation messages plus the new user message.
    """
    # get_chat_history returns a fresh list, so concatenation is safe here.
    return self.get_chat_history(system_prompt) + [
        {"role": "user", "content": utt},
    ]
197
204
@@ -209,6 +216,8 @@ def continue_chat(self, messages: MessageList,
209
216
Returns:
210
217
Optional[str]: The generated response or None if no response could be generated.
211
218
"""
219
+ if messages [0 ]["role" ] != "system" :
220
+ messages = [{"role" : "system" , "content" : self .system_prompt }] + messages
212
221
response = self ._do_api_request (messages )
213
222
answer = post_process_sentence (response )
214
223
if not answer or not answer .strip ("?" ) or not answer .strip ("_" ):
@@ -218,7 +227,7 @@ def continue_chat(self, messages: MessageList,
218
227
self .qa_pairs .append ((query , answer ))
219
228
return answer
220
229
221
- def stream_chat_utterances (self , messages : List [ Dict [ str , str ]] ,
230
+ def stream_chat_utterances (self , messages : MessageList ,
222
231
lang : Optional [str ] = None ,
223
232
units : Optional [str ] = None ) -> Iterable [str ]:
224
233
"""
@@ -232,6 +241,8 @@ def stream_chat_utterances(self, messages: List[Dict[str, str]],
232
241
Returns:
233
242
Iterable[str]: An iterable of utterances.
234
243
"""
244
+ if messages [0 ]["role" ] != "system" :
245
+ messages = [{"role" : "system" , "content" : self .system_prompt }] + messages
235
246
answer = ""
236
247
query = messages [- 1 ]["content" ]
237
248
if self .memory :
0 commit comments