@@ -36,6 +36,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
     ):
         """
         Base class for OpenAI LLM.
@@ -54,7 +55,7 @@ def __init__(
                 "Please install it with `pip install openai`."
             )
         self.openai = openai
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)

     def get_messages(
         self,
@@ -64,6 +65,32 @@ def get_messages(
             {"role": "system", "content": input},
         ]

+    def get_conversation_history(
+        self,
+        input: str,
+        chat_history: list[str],
+    ) -> Iterable[ChatCompletionMessageParam]:
+        messages = [{"role": "system", "content": self.system_instruction}]
+        for i, message in enumerate(chat_history):
+            if i % 2 == 0:
+                messages.append({"role": "user", "content": message})
+            else:
+                messages.append({"role": "assistant", "content": message})
+        messages.append({"role": "user", "content": input})
+        return messages
+
+    def chat(self, input: str, chat_history: list[str]) -> LLMResponse:
+        try:
+            response = self.client.chat.completions.create(
+                messages=self.get_conversation_history(input, chat_history),
+                model=self.model_name,
+                **self.model_params,
+            )
+            content = response.choices[0].message.content or ""
+            return LLMResponse(content=content)
+        except self.openai.OpenAIError as e:
+            raise LLMGenerationError(e)
+
     def invoke(self, input: str) -> LLMResponse:
         """Sends a text input to the OpenAI chat completion model
         and returns the response's content.
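Note on the new chat API above: a minimal usage sketch, not part of this diff. It assumes the OpenAI-backed subclass defined later in this file (referred to here as OpenAILLM, per its docstring) and uses placeholder model and prompt values. chat_history is expected to alternate user and assistant turns, which get_conversation_history maps onto role-tagged messages after the system instruction.

# Hypothetical usage sketch; class name, model name, and prompts are placeholders/assumptions.
llm = OpenAILLM(
    model_name="gpt-4o",  # placeholder model name
    system_instruction="You are a helpful assistant.",
)
history = [
    "What is Neo4j?",              # even index -> "user" turn
    "Neo4j is a graph database.",  # odd index  -> "assistant" turn
]
response = llm.chat(input="How do I query it?", chat_history=history)
print(response.content)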
@@ -118,6 +145,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         """OpenAI LLM
@@ -129,7 +157,7 @@ def __init__(
            model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
            kwargs: All other parameters will be passed to the openai.OpenAI init.
        """
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)
         self.client = self.openai.OpenAI(**kwargs)
         self.async_client = self.openai.AsyncOpenAI(**kwargs)

@@ -139,6 +167,7 @@ def __init__(
         self,
         model_name: str,
         model_params: Optional[dict[str, Any]] = None,
+        system_instruction: Optional[str] = None,
         **kwargs: Any,
     ):
         """Azure OpenAI LLM. Use this class when using an OpenAI model
@@ -149,6 +178,6 @@ def __init__(
            model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
            kwargs: All other parameters will be passed to the openai.OpenAI init.
        """
-        super().__init__(model_name, model_params)
+        super().__init__(model_name, model_params, system_instruction)
         self.client = self.openai.AzureOpenAI(**kwargs)
         self.async_client = self.openai.AsyncAzureOpenAI(**kwargs)
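For the Azure variant, a rough construction sketch with the new parameter. The class name (AzureOpenAILLM) is inferred from the docstring above, all credential values are placeholders, and the keyword arguments come from the openai SDK's AzureOpenAI client rather than from this diff.

# Hypothetical sketch; endpoint, key, version, and model name are placeholders.
azure_llm = AzureOpenAILLM(
    model_name="gpt-4o-mini",  # placeholder deployment/model name
    system_instruction="Answer using the provided graph context only.",
    api_key="<azure-openai-key>",
    api_version="2024-02-01",
    azure_endpoint="https://<resource>.openai.azure.com/",
)
print(azure_llm.invoke("Hello").content)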