Skip to content

Commit a70500d

Browse files
nicolaskrierchedim
authored andcommitted
Use Ollama system role for system message
Signed-off-by: Nicolas Krier <[email protected]>
1 parent 694252e commit a70500d

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,10 @@ Prompt buildRequestPrompt(Prompt prompt) {
439439
OllamaApi.ChatRequest ollamaChatRequest(Prompt prompt, boolean stream) {
440440

441441
List<OllamaApi.Message> ollamaMessages = prompt.getInstructions().stream().map(message -> {
442-
if (message.getMessageType() == MessageType.USER || message.getMessageType() == MessageType.SYSTEM) {
442+
if (message.getMessageType() == MessageType.SYSTEM) {
443+
return List.of(OllamaApi.Message.builder(Role.SYSTEM).content(message.getText()).build());
444+
}
445+
else if (message.getMessageType() == MessageType.USER) {
443446
var messageBuilder = OllamaApi.Message.builder(Role.USER).content(message.getText());
444447
if (message instanceof UserMessage userMessage) {
445448
if (!CollectionUtils.isEmpty(userMessage.getMedia())) {

models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,16 @@
1616

1717
package org.springframework.ai.ollama;
1818

19+
import java.util.List;
1920
import java.util.Map;
2021

2122
import org.junit.jupiter.api.Test;
2223

24+
import org.springframework.ai.chat.messages.AssistantMessage;
25+
import org.springframework.ai.chat.messages.Message;
26+
import org.springframework.ai.chat.messages.SystemMessage;
27+
import org.springframework.ai.chat.messages.ToolResponseMessage;
28+
import org.springframework.ai.chat.messages.UserMessage;
2329
import org.springframework.ai.chat.prompt.ChatOptions;
2430
import org.springframework.ai.chat.prompt.Prompt;
2531
import org.springframework.ai.model.tool.ToolCallingChatOptions;
@@ -36,10 +42,11 @@
3642
* @author Christian Tzolov
3743
* @author Thomas Vitale
3844
* @author Alexandros Pappas
45+
* @author Nicolas Krier
3946
*/
4047
class OllamaChatRequestTests {
4148

42-
OllamaChatModel chatModel = OllamaChatModel.builder()
49+
private final OllamaChatModel chatModel = OllamaChatModel.builder()
4350
.ollamaApi(OllamaApi.builder().build())
4451
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
4552
.retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE)
@@ -167,6 +174,54 @@ public void createRequestWithDefaultOptionsModelOverride() {
167174
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
168175
}
169176

177+
@Test
178+
void createRequestWithAllMessageTypes() {
179+
var prompt = this.chatModel.buildRequestPrompt(new Prompt(createMessagesWithAllMessageTypes()));
180+
181+
var request = this.chatModel.ollamaChatRequest(prompt, false);
182+
183+
assertThat(request.messages()).hasSize(6);
184+
185+
var ollamaSystemMessage = request.messages().get(0);
186+
assertThat(ollamaSystemMessage.role()).isEqualTo(OllamaApi.Message.Role.SYSTEM);
187+
assertThat(ollamaSystemMessage.content()).isEqualTo("Test system message");
188+
189+
var ollamaUserMessage = request.messages().get(1);
190+
assertThat(ollamaUserMessage.role()).isEqualTo(OllamaApi.Message.Role.USER);
191+
assertThat(ollamaUserMessage.content()).isEqualTo("Test user message");
192+
193+
var ollamaToolResponse1 = request.messages().get(2);
194+
assertThat(ollamaToolResponse1.role()).isEqualTo(OllamaApi.Message.Role.TOOL);
195+
assertThat(ollamaToolResponse1.content()).isEqualTo("Test tool response 1");
196+
197+
var ollamaToolResponse2 = request.messages().get(3);
198+
assertThat(ollamaToolResponse2.role()).isEqualTo(OllamaApi.Message.Role.TOOL);
199+
assertThat(ollamaToolResponse2.content()).isEqualTo("Test tool response 2");
200+
201+
var ollamaToolResponse3 = request.messages().get(4);
202+
assertThat(ollamaToolResponse3.role()).isEqualTo(OllamaApi.Message.Role.TOOL);
203+
assertThat(ollamaToolResponse3.content()).isEqualTo("Test tool response 3");
204+
205+
var ollamaAssistantMessage = request.messages().get(5);
206+
assertThat(ollamaAssistantMessage.role()).isEqualTo(OllamaApi.Message.Role.ASSISTANT);
207+
assertThat(ollamaAssistantMessage.content()).isEqualTo("Test assistant message");
208+
}
209+
210+
private static List<Message> createMessagesWithAllMessageTypes() {
211+
var systemMessage = new SystemMessage("Test system message");
212+
var userMessage = new UserMessage("Test user message");
213+
// @formatter:off
214+
var toolResponseMessage = new ToolResponseMessage(List.of(
215+
new ToolResponseMessage.ToolResponse("tool1", "Tool 1", "Test tool response 1"),
216+
new ToolResponseMessage.ToolResponse("tool2", "Tool 2", "Test tool response 2"),
217+
new ToolResponseMessage.ToolResponse("tool3", "Tool 3", "Test tool response 3"))
218+
);
219+
// @formatter:on
220+
var assistantMessage = new AssistantMessage("Test assistant message");
221+
222+
return List.of(systemMessage, userMessage, toolResponseMessage, assistantMessage);
223+
}
224+
170225
static class TestToolCallback implements ToolCallback {
171226

172227
private final ToolDefinition toolDefinition;

0 commit comments

Comments
 (0)