From 8bf5d44e36f3fbfe5d874b9838f7c5e639014448 Mon Sep 17 00:00:00 2001 From: Linar Abzaltdinov Date: Tue, 22 Apr 2025 11:53:53 +0300 Subject: [PATCH 1/4] chat-memory-cassandra : Fix message order after retrieving from db Signed-off-by: Linar Abzaltdinov --- .../ai/chat/memory/cassandra/CassandraChatMemory.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java index 035aed0ca5e..7fcee99bc45 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/main/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemory.java @@ -18,6 +18,7 @@ import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicLong; @@ -158,6 +159,7 @@ public List get(String sessionId, int lastN) { messages.add(new UserMessage(user)); } } + Collections.reverse(messages); return messages; } From 9915739785a8d8852c6e38bc13991552c5b39d24 Mon Sep 17 00:00:00 2001 From: Linar Abzaltdinov Date: Tue, 22 Apr 2025 11:55:04 +0300 Subject: [PATCH 2/4] chat-memory-cassandra : Added integration tests Signed-off-by: Linar Abzaltdinov --- .../cassandra/CassandraChatMemoryIT.java | 180 +++++++++++++++++- 1 file changed, 174 insertions(+), 6 deletions(-) diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java index c4cf7a8eed4..4725406bd83 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java @@ -16,21 +16,32 @@ package org.springframework.ai.chat.memory.cassandra; -import java.time.Duration; - import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; +import com.datastax.oss.driver.api.core.cql.ResultSet; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import org.testcontainers.containers.CassandraContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.autoconfigure.EnableAutoConfiguration; import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; +import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + +import java.time.Duration; +import java.util.List; +import java.util.UUID; + +import static org.assertj.core.api.Assertions.assertThat; /** * Use `mvn failsafe:integration-test -Dit.test=CassandraChatMemoryIT` @@ -57,6 +68,163 @@ void ensureBeanGetsCreated() { }); } + @ParameterizedTest + @CsvSource({ "Message from assistant,ASSISTANT", "Message from user,USER" }) + void add_shouldInsertSingleMessage(String content, MessageType messageType) { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var message = switch (messageType) { + case ASSISTANT -> new AssistantMessage(content); + case USER -> new UserMessage(content); + default -> throw new IllegalArgumentException("Type not supported: " + messageType); + }; + + chatMemory.add(sessionId, message); + + var cqlSession = context.getBean(CqlSession.class); + var query = """ + SELECT session_id, message_timestamp, a, u + FROM test_springframework.ai_chat_memory + WHERE session_id = ? + """; + ResultSet resultSet = cqlSession.execute(query, sessionId); + var rows = resultSet.all(); + + assertThat(rows.size()).isEqualTo(1); + + var firstRow = rows.get(0); + + assertThat(firstRow.getString("session_id")).isEqualTo(sessionId); + assertThat(firstRow.getInstant("message_timestamp")).isNotNull(); + if (messageType == MessageType.ASSISTANT) { + assertThat(firstRow.getString("a")).isEqualTo(content); + assertThat(firstRow.getString("u")).isNull(); + } + else if (messageType == MessageType.USER) { + assertThat(firstRow.getString("a")).isNull(); + assertThat(firstRow.getString("u")).isEqualTo(content); + } + }); + } + + @Test + void add_shouldInsertMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant"), + new UserMessage("Message from user")); + + chatMemory.add(sessionId, messages); + + var cqlSession = context.getBean(CqlSession.class); + var query = """ + SELECT session_id, message_timestamp, a, u + FROM test_springframework.ai_chat_memory + WHERE session_id = ? + ORDER BY message_timestamp ASC + """; + ResultSet resultSet = cqlSession.execute(query, sessionId); + var rows = resultSet.all(); + + assertThat(rows.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = rows.get(i); + + assertThat(result.getString("session_id")).isNotNull(); + assertThat(result.getString("session_id")).isEqualTo(sessionId); + if (message.getMessageType() == MessageType.ASSISTANT) { + assertThat(result.getString("a")).isEqualTo(message.getText()); + assertThat(result.getString("u")).isNull(); + } + else if (message.getMessageType() == MessageType.USER) { + assertThat(result.getString("a")).isNull(); + assertThat(result.getString("u")).isEqualTo(message.getText()); + } + assertThat(result.getInstant("message_timestamp")).isNotNull(); + } + }); + } + + @Test + void get_shouldReturnMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant 1 - " + sessionId), + new AssistantMessage("Message from assistant 2 - " + sessionId), + new UserMessage("Message from user - " + sessionId)); + + chatMemory.add(sessionId, messages); + + var results = chatMemory.get(sessionId, Integer.MAX_VALUE); + + assertThat(results.size()).isEqualTo(messages.size()); + + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.getMessageType()).isEqualTo(message.getMessageType()); + assertThat(result.getText()).isEqualTo(message.getText()); + } + }); + } + + @Test + void get_afterMultipleAdds_shouldReturnMessagesInSameOrder() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var userMessage = new UserMessage("Message from user - " + sessionId); + var assistantMessage = new AssistantMessage("Message from assistant - " + sessionId); + + chatMemory.add(sessionId, userMessage); + chatMemory.add(sessionId, assistantMessage); + + var results = chatMemory.get(sessionId, Integer.MAX_VALUE); + + assertThat(results.size()).isEqualTo(2); + + var messages = List.of(userMessage, assistantMessage); + for (var i = 0; i < messages.size(); i++) { + var message = messages.get(i); + var result = results.get(i); + + assertThat(result.getMessageType()).isEqualTo(message.getMessageType()); + assertThat(result.getText()).isEqualTo(message.getText()); + } + }); + } + + @Test + void clear_shouldDeleteMessages() { + this.contextRunner.run(context -> { + var chatMemory = context.getBean(ChatMemory.class); + var sessionId = UUID.randomUUID().toString(); + var messages = List.of(new AssistantMessage("Message from assistant - " + sessionId), + new UserMessage("Message from user - " + sessionId)); + + chatMemory.add(sessionId, messages); + + chatMemory.clear(sessionId); + + var cqlSession = context.getBean(CqlSession.class); + var query = """ + SELECT COUNT(*) + FROM test_springframework.ai_chat_memory + WHERE session_id = ? + """; + ResultSet resultSet = cqlSession.execute(query, sessionId); + var count = resultSet.all().get(0).getLong(0); + + assertThat(count).isZero(); + }); + } + @SpringBootConfiguration @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) public static class TestApplication { From 20036678031b510f5f814a942976dbcb85cfd4a1 Mon Sep 17 00:00:00 2001 From: Linar Abzaltdinov Date: Tue, 22 Apr 2025 13:41:33 +0300 Subject: [PATCH 3/4] chat-memory-cassandra : replaced deprecated CassandraContainer test-container by actual one Signed-off-by: Linar Abzaltdinov --- .../autoconfigure/CassandraChatMemoryAutoConfigurationIT.java | 2 +- .../ai/chat/memory/cassandra/CassandraChatMemoryIT.java | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java index 76bac2e38a7..cdc5fea23c9 100644 --- a/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java +++ b/auto-configurations/models/chat/memory/spring-ai-autoconfigure-model-chat-memory-cassandra/src/test/java/org/springframework/ai/model/chat/memory/cassandra/autoconfigure/CassandraChatMemoryAutoConfigurationIT.java @@ -21,7 +21,7 @@ import com.datastax.driver.core.utils.UUIDs; import org.junit.jupiter.api.Test; -import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.cassandra.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; import org.testcontainers.utility.DockerImageName; diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java index 4725406bd83..f0f9d3f7967 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java @@ -33,7 +33,7 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import org.testcontainers.containers.CassandraContainer; +import org.testcontainers.cassandra.CassandraContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -54,7 +54,7 @@ class CassandraChatMemoryIT { @Container - static CassandraContainer cassandraContainer = new CassandraContainer<>(CassandraImage.DEFAULT_IMAGE); + static CassandraContainer cassandraContainer = new CassandraContainer(CassandraImage.DEFAULT_IMAGE); private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() .withUserConfiguration(CassandraChatMemoryIT.TestApplication.class); From 82167ac3c87de1b81028f5060e22e17f566247e6 Mon Sep 17 00:00:00 2001 From: Linar Abzaltdinov Date: Wed, 23 Apr 2025 14:38:07 +0300 Subject: [PATCH 4/4] chat-memory-cassandra : reordered imports in CassandraChatMemoryIT Signed-off-by: Linar Abzaltdinov --- .../memory/cassandra/CassandraChatMemoryIT.java | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java index f0f9d3f7967..8fbf0059cc7 100644 --- a/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java +++ b/memory/spring-ai-model-chat-memory-cassandra/src/test/java/org/springframework/ai/chat/memory/cassandra/CassandraChatMemoryIT.java @@ -16,6 +16,10 @@ package org.springframework.ai.chat.memory.cassandra; +import java.time.Duration; +import java.util.List; +import java.util.UUID; + import com.datastax.oss.driver.api.core.CqlSession; import com.datastax.oss.driver.api.core.CqlSessionBuilder; import com.datastax.oss.driver.api.core.cql.ResultSet; @@ -23,6 +27,10 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.CsvSource; +import org.testcontainers.cassandra.CassandraContainer; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; + import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.Message; @@ -33,13 +41,6 @@ import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import org.testcontainers.cassandra.CassandraContainer; -import org.testcontainers.junit.jupiter.Container; -import org.testcontainers.junit.jupiter.Testcontainers; - -import java.time.Duration; -import java.util.List; -import java.util.UUID; import static org.assertj.core.api.Assertions.assertThat;