|
16 | 16 |
|
17 | 17 | package org.springframework.ai.ollama; |
18 | 18 |
|
| 19 | +import java.util.List; |
19 | 20 | import java.util.Map; |
20 | 21 |
|
21 | 22 | import org.junit.jupiter.api.Test; |
22 | 23 |
|
| 24 | +import org.springframework.ai.chat.messages.AssistantMessage; |
| 25 | +import org.springframework.ai.chat.messages.Message; |
| 26 | +import org.springframework.ai.chat.messages.SystemMessage; |
| 27 | +import org.springframework.ai.chat.messages.ToolResponseMessage; |
| 28 | +import org.springframework.ai.chat.messages.UserMessage; |
23 | 29 | import org.springframework.ai.chat.prompt.ChatOptions; |
24 | 30 | import org.springframework.ai.chat.prompt.Prompt; |
25 | 31 | import org.springframework.ai.model.tool.ToolCallingChatOptions; |
|
36 | 42 | * @author Christian Tzolov |
37 | 43 | * @author Thomas Vitale |
38 | 44 | * @author Alexandros Pappas |
| 45 | + * @author Nicolas Krier |
39 | 46 | */ |
40 | 47 | class OllamaChatRequestTests { |
41 | 48 |
|
42 | | - OllamaChatModel chatModel = OllamaChatModel.builder() |
| 49 | + private final OllamaChatModel chatModel = OllamaChatModel.builder() |
43 | 50 | .ollamaApi(OllamaApi.builder().build()) |
44 | 51 | .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) |
45 | 52 | .retryTemplate(RetryUtils.DEFAULT_RETRY_TEMPLATE) |
@@ -167,6 +174,54 @@ public void createRequestWithDefaultOptionsModelOverride() { |
167 | 174 | assertThat(request.model()).isEqualTo("PROMPT_MODEL"); |
168 | 175 | } |
169 | 176 |
|
| 177 | + @Test |
| 178 | + void createRequestWithAllMessageTypes() { |
| 179 | + var prompt = this.chatModel.buildRequestPrompt(new Prompt(createMessagesWithAllMessageTypes())); |
| 180 | + |
| 181 | + var request = this.chatModel.ollamaChatRequest(prompt, false); |
| 182 | + |
| 183 | + assertThat(request.messages()).hasSize(6); |
| 184 | + |
| 185 | + var ollamaSystemMessage = request.messages().get(0); |
| 186 | + assertThat(ollamaSystemMessage.role()).isEqualTo(OllamaApi.Message.Role.SYSTEM); |
| 187 | + assertThat(ollamaSystemMessage.content()).isEqualTo("Test system message"); |
| 188 | + |
| 189 | + var ollamaUserMessage = request.messages().get(1); |
| 190 | + assertThat(ollamaUserMessage.role()).isEqualTo(OllamaApi.Message.Role.USER); |
| 191 | + assertThat(ollamaUserMessage.content()).isEqualTo("Test user message"); |
| 192 | + |
| 193 | + var ollamaToolResponse1 = request.messages().get(2); |
| 194 | + assertThat(ollamaToolResponse1.role()).isEqualTo(OllamaApi.Message.Role.TOOL); |
| 195 | + assertThat(ollamaToolResponse1.content()).isEqualTo("Test tool response 1"); |
| 196 | + |
| 197 | + var ollamaToolResponse2 = request.messages().get(3); |
| 198 | + assertThat(ollamaToolResponse2.role()).isEqualTo(OllamaApi.Message.Role.TOOL); |
| 199 | + assertThat(ollamaToolResponse2.content()).isEqualTo("Test tool response 2"); |
| 200 | + |
| 201 | + var ollamaToolResponse3 = request.messages().get(4); |
| 202 | + assertThat(ollamaToolResponse3.role()).isEqualTo(OllamaApi.Message.Role.TOOL); |
| 203 | + assertThat(ollamaToolResponse3.content()).isEqualTo("Test tool response 3"); |
| 204 | + |
| 205 | + var ollamaAssistantMessage = request.messages().get(5); |
| 206 | + assertThat(ollamaAssistantMessage.role()).isEqualTo(OllamaApi.Message.Role.ASSISTANT); |
| 207 | + assertThat(ollamaAssistantMessage.content()).isEqualTo("Test assistant message"); |
| 208 | + } |
| 209 | + |
| 210 | + private static List<Message> createMessagesWithAllMessageTypes() { |
| 211 | + var systemMessage = new SystemMessage("Test system message"); |
| 212 | + var userMessage = new UserMessage("Test user message"); |
| 213 | + // @formatter:off |
| 214 | + var toolResponseMessage = new ToolResponseMessage(List.of( |
| 215 | + new ToolResponseMessage.ToolResponse("tool1", "Tool 1", "Test tool response 1"), |
| 216 | + new ToolResponseMessage.ToolResponse("tool2", "Tool 2", "Test tool response 2"), |
| 217 | + new ToolResponseMessage.ToolResponse("tool3", "Tool 3", "Test tool response 3")) |
| 218 | + ); |
| 219 | + // @formatter:on |
| 220 | + var assistantMessage = new AssistantMessage("Test assistant message"); |
| 221 | + |
| 222 | + return List.of(systemMessage, userMessage, toolResponseMessage, assistantMessage); |
| 223 | + } |
| 224 | + |
170 | 225 | static class TestToolCallback implements ToolCallback { |
171 | 226 |
|
172 | 227 | private final ToolDefinition toolDefinition; |
|
0 commit comments