Skip to content

Commit 8ff7430

Browse files
authored
Merge pull request #19 from Delfshkrimm/patch-1
Use system_prompt in solver configuration (breaking change)
1 parent 398e718 commit 8ff7430

File tree

4 files changed

+30
-49
lines changed

4 files changed

+30
-49
lines changed

ovos_solver_openai_persona/__init__.py

Lines changed: 9 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,14 @@
1-
from typing import Optional
2-
1+
import warnings
32
from ovos_solver_openai_persona.engines import OpenAIChatCompletionsSolver
43

5-
64
class OpenAIPersonaSolver(OpenAIChatCompletionsSolver):
7-
"""default "Persona" engine"""
8-
9-
def __init__(self, config=None):
10-
# defaults to gpt-3.5-turbo
11-
super().__init__(config=config)
12-
self.default_persona = config.get("persona") or "helpful, creative, clever, and very friendly."
13-
14-
def get_chat_history(self, persona=None):
15-
persona = persona or self.default_persona
16-
initial_prompt = f"You are a helpful assistant. " \
17-
f"You give short and factual answers. " \
18-
f"You are {persona}"
19-
return super().get_chat_history(initial_prompt)
20-
21-
# officially exported Solver methods
22-
def get_spoken_answer(self, query: str,
23-
lang: Optional[str] = None,
24-
units: Optional[str] = None) -> Optional[str]:
25-
"""
26-
Obtain the spoken answer for a given query.
27-
28-
Args:
29-
query (str): The query text.
30-
lang (Optional[str]): Optional language code. Defaults to None.
31-
units (Optional[str]): Optional units for the query. Defaults to None.
32-
33-
Returns:
34-
str: The spoken answer as a text response.
35-
"""
36-
answer = super().get_spoken_answer(query, lang, units)
37-
if not answer or not answer.strip("?") or not answer.strip("_"):
38-
return None
39-
return answer
40-
5+
def __init__(self, *args, **kwargs):
6+
warnings.warn(
7+
"use OpenAIChatCompletionsSolver instead",
8+
DeprecationWarning,
9+
stacklevel=2,
10+
)
11+
super().__init__(*args, **kwargs)
4112

4213
# for ovos-persona
4314
LLAMA_DEMO = {
@@ -51,9 +22,8 @@ def get_spoken_answer(self, query: str,
5122
}
5223
}
5324

54-
5525
if __name__ == "__main__":
56-
bot = OpenAIPersonaSolver(LLAMA_DEMO["ovos-solver-openai-plugin"])
26+
bot = OpenAIChatCompletionsSolver(LLAMA_DEMO["ovos-solver-openai-persona-plugin"])
5727
#for utt in bot.stream_utterances("describe quantum mechanics in simple terms"):
5828
# print(utt)
5929
# Quantum mechanics is a branch of physics that studies the behavior of atoms and particles at the smallest scales.

ovos_solver_openai_persona/dialog_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def __init__(self, name="ovos-dialog-transformer-openai-plugin", priority=10, co
1212
"key": self.config.get("key"),
1313
'api_url': self.config.get('api_url', 'https://api.openai.com/v1'),
1414
"enable_memory": False,
15-
"initial_prompt": "your task is to rewrite text as if it was spoken by a different character"
15+
"system_prompt": self.config.get("system_prompt") or "Your task is to rewrite text as if it was spoken by a different character"
1616
})
1717

1818
def transform(self, dialog: str, context: dict = None) -> Tuple[str, dict]:

ovos_solver_openai_persona/engines.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, config=None,
2424
enable_tx=enable_tx, enable_cache=enable_cache,
2525
internal_lang=internal_lang)
2626
self.api_url = f"{self.config.get('api_url', 'https://api.openai.com/v1')}/completions"
27-
self.engine = self.config.get("model", "text-davinci-002") # "ada" cheaper and faster, "davinci" better
27+
self.engine = self.config.get("model", "gpt-4o-mini")
2828
self.key = self.config.get("key")
2929
if not self.key:
3030
LOG.error("key not set in config")
@@ -107,7 +107,14 @@ def __init__(self, config=None,
107107
self.memory = config.get("enable_memory", True)
108108
self.max_utts = config.get("memory_size", 3)
109109
self.qa_pairs = [] # tuple of q+a
110-
self.initial_prompt = config.get("initial_prompt", "You are a helpful assistant.")
110+
if "persona" in config:
111+
LOG.warning("'persona' config option is deprecated, use 'system_prompt' instead")
112+
if "initial_prompt" in config:
113+
LOG.warning("'initial_prompt' config option is deprecated, use 'system_prompt' instead")
114+
self.system_prompt = config.get("system_prompt") or config.get("initial_prompt")
115+
if not self.system_prompt:
116+
self.system_prompt = "You are a helpful assistant."
117+
LOG.error(f"system prompt not set in config! defaulting to '{self.system_prompt}'")
111118

112119
# OpenAI API integration
113120
def _do_api_request(self, messages):
@@ -179,19 +186,19 @@ def _do_streaming_api_request(self, messages):
179186
continue
180187
yield chunk["choices"][0]["delta"]["content"]
181188

182-
def get_chat_history(self, initial_prompt=None):
189+
def get_chat_history(self, system_prompt=None):
183190
qa = self.qa_pairs[-1 * self.max_utts:]
184-
initial_prompt = initial_prompt or self.initial_prompt or "You are a helpful assistant."
191+
system_prompt = system_prompt or self.system_prompt or "You are a helpful assistant."
185192
messages = [
186-
{"role": "system", "content": initial_prompt},
193+
{"role": "system", "content": system_prompt},
187194
]
188195
for q, a in qa:
189196
messages.append({"role": "user", "content": q})
190197
messages.append({"role": "assistant", "content": a})
191198
return messages
192199

193-
def get_messages(self, utt, initial_prompt=None) -> MessageList:
194-
messages = self.get_chat_history(initial_prompt)
200+
def get_messages(self, utt, system_prompt=None) -> MessageList:
201+
messages = self.get_chat_history(system_prompt)
195202
messages.append({"role": "user", "content": utt})
196203
return messages
197204

@@ -209,6 +216,8 @@ def continue_chat(self, messages: MessageList,
209216
Returns:
210217
Optional[str]: The generated response or None if no response could be generated.
211218
"""
219+
if messages[0]["role"] != "system":
220+
messages = [{"role": "system", "content": self.system_prompt }] + messages
212221
response = self._do_api_request(messages)
213222
answer = post_process_sentence(response)
214223
if not answer or not answer.strip("?") or not answer.strip("_"):
@@ -218,7 +227,7 @@ def continue_chat(self, messages: MessageList,
218227
self.qa_pairs.append((query, answer))
219228
return answer
220229

221-
def stream_chat_utterances(self, messages: List[Dict[str, str]],
230+
def stream_chat_utterances(self, messages: MessageList,
222231
lang: Optional[str] = None,
223232
units: Optional[str] = None) -> Iterable[str]:
224233
"""
@@ -232,6 +241,8 @@ def stream_chat_utterances(self, messages: List[Dict[str, str]],
232241
Returns:
233242
Iterable[str]: An iterable of utterances.
234243
"""
244+
if messages[0]["role"] != "system":
245+
messages = [{"role": "system", "content": self.system_prompt }] + messages
235246
answer = ""
236247
query = messages[-1]["content"]
237248
if self.memory:

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def get_version():
4848

4949

5050
PERSONA_ENTRY_POINT = 'Remote Llama=ovos_solver_openai_persona:LLAMA_DEMO'
51-
PLUGIN_ENTRY_POINT = 'ovos-solver-openai-plugin=ovos_solver_openai_persona:OpenAIPersonaSolver'
51+
PLUGIN_ENTRY_POINT = 'ovos-solver-openai-plugin=ovos_solver_openai_persona.engines:OpenAICompletionsSolver'
5252
DIALOG_PLUGIN_ENTRY_POINT = 'ovos-dialog-transformer-openai-plugin=ovos_solver_openai_persona.dialog_transformers:OpenAIDialogTransformer'
5353
SUMMARIZER_ENTRY_POINT = 'ovos-summarizer-openai-plugin=ovos_solver_openai_persona.summarizer:OpenAISummarizer'
5454

0 commit comments

Comments (0)