Skip to content

Commit e1c0fc3

Browse files
authored
fix: chat roles for model responses in chat generators (#1030)
1 parent 81f66c8 commit e1c0fc3

File tree

5 files changed

+18
-17
lines changed

5 files changed

+18
-17
lines changed

integrations/amazon_bedrock/tests/test_chat_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
200200
"""
201201
Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
202202
"""
203-
messages = [ChatMessage.from_system("What is the biggest city in United States?")]
203+
messages = [ChatMessage.from_user("What is the biggest city in United States?")]
204204

205205
# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
206206
max_length_generated_text = 3

integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,14 +230,14 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
230230
raise ValueError(msg)
231231

232232
def _message_to_part(self, message: ChatMessage) -> Part:
233-
if message.role == ChatRole.SYSTEM and message.name:
233+
if message.role == ChatRole.ASSISTANT and message.name:
234234
p = Part()
235235
p.function_call.name = message.name
236236
p.function_call.args = {}
237237
for k, v in message.content.items():
238238
p.function_call.args[k] = v
239239
return p
240-
elif message.role == ChatRole.SYSTEM:
240+
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
241241
p = Part()
242242
p.text = message.content
243243
return p
@@ -250,13 +250,13 @@ def _message_to_part(self, message: ChatMessage) -> Part:
250250
return self._convert_part(message.content)
251251

252252
def _message_to_content(self, message: ChatMessage) -> Content:
253-
if message.role == ChatRole.SYSTEM and message.name:
253+
if message.role == ChatRole.ASSISTANT and message.name:
254254
part = Part()
255255
part.function_call.name = message.name
256256
part.function_call.args = {}
257257
for k, v in message.content.items():
258258
part.function_call.args[k] = v
259-
elif message.role == ChatRole.SYSTEM:
259+
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
260260
part = Part()
261261
part.text = message.content
262262
elif message.role == ChatRole.FUNCTION:
@@ -315,12 +315,12 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
315315
for candidate in response_body.candidates:
316316
for part in candidate.content.parts:
317317
if part.text != "":
318-
replies.append(ChatMessage.from_system(part.text))
318+
replies.append(ChatMessage.from_assistant(part.text))
319319
elif part.function_call is not None:
320320
replies.append(
321321
ChatMessage(
322322
content=dict(part.function_call.args.items()),
323-
role=ChatRole.SYSTEM,
323+
role=ChatRole.ASSISTANT,
324324
name=part.function_call.name,
325325
)
326326
)
@@ -343,4 +343,4 @@ def _get_stream_response(
343343
responses.append(content)
344344

345345
combined_response = "".join(responses).lstrip()
346-
return [ChatMessage.from_system(content=combined_response)]
346+
return [ChatMessage.from_assistant(content=combined_response)]

integrations/google_ai/tests/generators/chat/test_chat_gemini.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,8 +256,9 @@ def get_current_weather(location: str, unit: str = "celsius"): # noqa: ARG001
256256
def test_past_conversation():
257257
gemini_chat = GoogleAIGeminiChatGenerator(model="gemini-pro")
258258
messages = [
259+
ChatMessage.from_system(content="You are a knowledgeable mathematician."),
259260
ChatMessage.from_user(content="What is 2+2?"),
260-
ChatMessage.from_system(content="It's an arithmetic operation."),
261+
ChatMessage.from_assistant(content="It's an arithmetic operation."),
261262
ChatMessage.from_user(content="Yeah, but what's the result?"),
262263
]
263264
res = gemini_chat.run(messages=messages)

integrations/google_vertex/src/haystack_integrations/components/generators/google_vertex/chat/gemini.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -161,24 +161,24 @@ def _convert_part(self, part: Union[str, ByteStream, Part]) -> Part:
161161
raise ValueError(msg)
162162

163163
def _message_to_part(self, message: ChatMessage) -> Part:
164-
if message.role == ChatRole.SYSTEM and message.name:
164+
if message.role == ChatRole.ASSISTANT and message.name:
165165
p = Part.from_dict({"function_call": {"name": message.name, "args": {}}})
166166
for k, v in message.content.items():
167167
p.function_call.args[k] = v
168168
return p
169-
elif message.role == ChatRole.SYSTEM:
169+
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
170170
return Part.from_text(message.content)
171171
elif message.role == ChatRole.FUNCTION:
172172
return Part.from_function_response(name=message.name, response=message.content)
173173
elif message.role == ChatRole.USER:
174174
return self._convert_part(message.content)
175175

176176
def _message_to_content(self, message: ChatMessage) -> Content:
177-
if message.role == ChatRole.SYSTEM and message.name:
177+
if message.role == ChatRole.ASSISTANT and message.name:
178178
part = Part.from_dict({"function_call": {"name": message.name, "args": {}}})
179179
for k, v in message.content.items():
180180
part.function_call.args[k] = v
181-
elif message.role == ChatRole.SYSTEM:
181+
elif message.role in {ChatRole.SYSTEM, ChatRole.ASSISTANT}:
182182
part = Part.from_text(message.content)
183183
elif message.role == ChatRole.FUNCTION:
184184
part = Part.from_function_response(name=message.name, response=message.content)
@@ -233,12 +233,12 @@ def _get_response(self, response_body: GenerationResponse) -> List[ChatMessage]:
233233
for candidate in response_body.candidates:
234234
for part in candidate.content.parts:
235235
if part._raw_part.text != "":
236-
replies.append(ChatMessage.from_system(part.text))
236+
replies.append(ChatMessage.from_assistant(part.text))
237237
elif part.function_call is not None:
238238
replies.append(
239239
ChatMessage(
240240
content=dict(part.function_call.args.items()),
241-
role=ChatRole.SYSTEM,
241+
role=ChatRole.ASSISTANT,
242242
name=part.function_call.name,
243243
)
244244
)
@@ -261,4 +261,4 @@ def _get_stream_response(
261261
responses.append(streaming_chunk.content)
262262

263263
combined_response = "".join(responses).lstrip()
264-
return [ChatMessage.from_system(content=combined_response)]
264+
return [ChatMessage.from_assistant(content=combined_response)]

integrations/ollama/examples/chat_generator_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
messages = [
1313
ChatMessage.from_user("What's Natural Language Processing?"),
14-
ChatMessage.from_system(
14+
ChatMessage.from_assistant(
1515
"Natural Language Processing (NLP) is a field of computer science and artificial "
1616
"intelligence concerned with the interaction between computers and human language"
1717
),

0 commit comments

Comments
 (0)