diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java index 1945ef09..e02b6776 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/GeneralAssistantFactory.java @@ -19,6 +19,7 @@ package org.apache.bigtop.manager.ai.assistant; import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider; +import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore; import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory; import org.apache.bigtop.manager.ai.core.enums.PlatformType; import org.apache.bigtop.manager.ai.core.enums.SystemPrompt; @@ -70,7 +71,10 @@ public AIAssistant createWithPrompt( case QIANFAN -> QianFanAssistant.builder(); }; AIAssistant aiAssistant = builder.id(id) - .memoryStore((id == null) ? new InMemoryChatMemoryStore() : chatMemoryStore) + .memoryStore( + (id == null) + ? new InMemoryChatMemoryStore() + : ((PersistentChatMemoryStore) chatMemoryStore).clone()) .withConfigProvider(assistantConfig) .build(); diff --git a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java index 91352880..3a905cea 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java +++ b/bigtop-manager-ai/bigtop-manager-ai-assistant/src/main/java/org/apache/bigtop/manager/ai/assistant/store/PersistentChatMemoryStore.java @@ -33,13 +33,14 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.stream.Collectors; public class PersistentChatMemoryStore implements ChatMemoryStore { private final ChatThreadDao chatThreadDao; private final ChatMessageDao chatMessageDao; + private final List systemMessages = new ArrayList<>(); + public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) { this.chatThreadDao = chatThreadDao; this.chatMessageDao = chatMessageDao; @@ -78,19 +79,25 @@ private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatT @Override public List getMessages(Object threadId) { List chatMessages = chatMessageDao.findAllByThreadId((Long) threadId); - if (chatMessages.isEmpty()) { - return new ArrayList<>(); - } else { - return chatMessages.stream() + List allChatMessages = new ArrayList<>(systemMessages); + if (!chatMessages.isEmpty()) { + allChatMessages.addAll(chatMessages.stream() .map(this::convertToChatMessage) .filter(Objects::nonNull) - .collect(Collectors.toList()); + .toList()); } + return allChatMessages; } @Override public void updateMessages(Object threadId, List messages) { - ChatMessagePO chatMessagePO = convertToChatMessagePO(messages.get(messages.size() - 1), (Long) threadId); + ChatMessage newMessage = messages.get(messages.size() - 1); + if (newMessage.type().equals(ChatMessageType.SYSTEM)) { + SystemMessage systemMessage = (SystemMessage) newMessage; + systemMessages.add(systemMessage); + return; + } + ChatMessagePO chatMessagePO = convertToChatMessagePO(newMessage, (Long) threadId); if (chatMessagePO == null) { return; } @@ -103,4 +110,8 @@ public void deleteMessages(Object threadId) { chatMessagePOS.forEach(chatMessage -> chatMessage.setIsDeleted(true)); chatMessageDao.partialUpdateByIds(chatMessagePOS); } + + public PersistentChatMemoryStore clone() { + return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao); + } } diff --git a/bigtop-manager-ai/bigtop-manager-ai-dashscope/src/main/java/org/apache/bigtop/manager/ai/dashscope/DashScopeAssistant.java b/bigtop-manager-ai/bigtop-manager-ai-dashscope/src/main/java/org/apache/bigtop/manager/ai/dashscope/DashScopeAssistant.java index 6ea7fd4f..f178f8f0 100644 --- a/bigtop-manager-ai/bigtop-manager-ai-dashscope/src/main/java/org/apache/bigtop/manager/ai/dashscope/DashScopeAssistant.java +++ b/bigtop-manager-ai/bigtop-manager-ai-dashscope/src/main/java/org/apache/bigtop/manager/ai/dashscope/DashScopeAssistant.java @@ -88,7 +88,7 @@ private String getValueFromAssistantStreamMessage(AssistantStreamMessage assista return streamMessage.toString(); } - private void saveMessage(String message, MessageType sender) { + private void addMessage(String message, MessageType sender) { ChatMessage chatMessage; if (sender.equals(MessageType.AI)) { chatMessage = new AiMessage(message); @@ -131,7 +131,7 @@ public void setSystemPrompt(String systemPrompt) { } catch (NoApiKeyException | InputRequiredException | InvalidateParameter e) { throw new RuntimeException(e); } - saveMessage(systemPrompt, MessageType.SYSTEM); + addMessage(systemPrompt, MessageType.SYSTEM); } public static Builder builder() { @@ -140,7 +140,7 @@ public static Builder builder() { @Override public Flux streamAsk(String userMessage) { - saveMessage(userMessage, MessageType.USER); + addMessage(userMessage, MessageType.USER); TextMessageParam textMessageParam = TextMessageParam.builder() .apiKey(dashScopeThreadParam.getApiKey()) .role(Role.USER.getValue()) @@ -174,13 +174,13 @@ public Flux streamAsk(String userMessage) { return message; }) .doOnComplete(() -> { - saveMessage(finalMessage.toString(), MessageType.AI); + addMessage(finalMessage.toString(), MessageType.AI); }); } @Override public String ask(String userMessage) { - saveMessage(userMessage, MessageType.USER); + addMessage(userMessage, MessageType.USER); TextMessageParam textMessageParam = TextMessageParam.builder() .apiKey(dashScopeThreadParam.getApiKey()) .role(Role.USER.getValue()) @@ -244,7 +244,7 @@ public String ask(String userMessage) { ContentText contentText = (ContentText) content; finalMessage.append(contentText.getText().getValue()); } - saveMessage(finalMessage.toString(), MessageType.AI); + addMessage(finalMessage.toString(), MessageType.AI); return finalMessage.toString(); }