Skip to content

Commit

Permalink
add isPersistent flag
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Sep 11, 2024
1 parent 83238f8 commit 0048ed1
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,20 @@ public AIAssistant createWithPrompt(
PlatformType platformType,
AIAssistantConfigProvider assistantConfig,
Object id,
SystemPrompt systemPrompts) {
SystemPrompt systemPrompts,
boolean isPersistent) {
AIAssistant aiAssistant;
if (Objects.requireNonNull(platformType) == PlatformType.OPENAI) {
aiAssistant = OpenAIAssistant.builder()
.id(id)
.memoryStore(messageStoreProvider.getChatMemoryStore())
.memoryStore(messageStoreProvider.getChatMemoryStore(isPersistent))
.withConfigProvider(assistantConfig)
.build();
} else if (Objects.requireNonNull(platformType) == PlatformType.DASH_SCOPE) {
aiAssistant = DashScopeAssistant.builder()
.id(id)
.withConfigProvider(assistantConfig)
.messageRepository(messageStoreProvider.getMessageRepository())
.messageRepository(messageStoreProvider.getMessageRepository(isPersistent))
.build();
} else {
throw new PlatformNotFoundException(platformType.getValue());
Expand All @@ -91,8 +92,9 @@ public AIAssistant createWithPrompt(
}

@Override
public AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id) {
return createWithPrompt(platformType, assistantConfig, id, SystemPrompt.DEFAULT_PROMPT);
public AIAssistant create(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id, boolean isPersistent) {
return createWithPrompt(platformType, assistantConfig, id, SystemPrompt.DEFAULT_PROMPT, isPersistent);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package org.apache.bigtop.manager.ai.assistant.provider;

import org.apache.bigtop.manager.ai.assistant.store.NonPersistentMessageRepository;
import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore;
import org.apache.bigtop.manager.ai.assistant.store.PersistentMessageRepository;
import org.apache.bigtop.manager.ai.core.provider.MessageStoreProvider;
Expand Down Expand Up @@ -47,13 +48,16 @@ public PersistentStoreProvider() {
}

@Override
public MessageRepository getMessageRepository() {
public MessageRepository getMessageRepository(boolean isPersistent) {
if (isPersistent || chatThreadDao == null || chatMessageDao == null) {
return new NonPersistentMessageRepository();
}
return new PersistentMessageRepository(chatThreadDao, chatMessageDao);
}

@Override
public ChatMemoryStore getChatMemoryStore() {
if (chatThreadDao == null) {
public ChatMemoryStore getChatMemoryStore(boolean isPersistent) {
if (isPersistent || chatThreadDao == null || chatMessageDao == null) {
return new InMemoryChatMemoryStore();
}
return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.store;

import org.apache.bigtop.manager.ai.core.repository.MessageRepository;

public class NonPersistentMessageRepository implements MessageRepository {}
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,12 @@ public class PersistentMessageRepository implements MessageRepository {
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

private boolean noPersistent() {
return chatThreadDao == null || chatMessageDao == null;
}

public PersistentMessageRepository(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

private ChatMessagePO getChatMessagePO(String message, Long threadId, MessageSender sender) {
if (noPersistent()) {
return null;
}
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
ChatMessagePO chatMessagePO = new ChatMessagePO();
chatMessagePO.setUserId(chatThreadPO.getUserId());
Expand All @@ -53,27 +46,18 @@ private ChatMessagePO getChatMessagePO(String message, Long threadId, MessageSen

@Override
public void saveUserMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.USER);
chatMessageDao.save(chatMessagePO);
}

@Override
public void saveAiMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.AI);
chatMessageDao.save(chatMessagePO);
}

@Override
public void saveSystemMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.SYSTEM);
chatMessageDao.save(chatMessagePO);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,22 @@ public void init() {
emmit.next(text.charAt(i) + "");
}
}));
when(aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId))
when(aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId, false))
.thenReturn(this.aiAssistant);
when(aiAssistant.getPlatform()).thenReturn(PlatformType.OPENAI);
}

@Test
public void createNew2SimpleChat() {
AIAssistant aiAssistant = aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId);
AIAssistant aiAssistant = aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId, false);
String ask = aiAssistant.ask("1?");
assertFalse(ask.isEmpty());
System.out.println(ask);
}

@Test
public void createNew2StreamChat() throws InterruptedException {
AIAssistant aiAssistant = aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId);
AIAssistant aiAssistant = aiAssistantFactory.create(PlatformType.OPENAI, configProvider, threadId, false);
Flux<String> stringFlux = aiAssistant.streamAsk("stream 1?");
stringFlux.subscribe(
System.out::println,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,18 @@
public interface AIAssistantFactory {

AIAssistant createWithPrompt(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id, SystemPrompt systemPrompt);
PlatformType platformType,
AIAssistantConfigProvider assistantConfig,
Object id,
SystemPrompt systemPrompt,
boolean isPersistent);

AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id);
AIAssistant create(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, Object id, boolean isPersistent);

default AIAssistant create(PlatformType platformType, AIAssistantConfigProvider assistantConfig) {
return create(platformType, assistantConfig, null);
default AIAssistant create(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, boolean isPersistent) {
return create(platformType, assistantConfig, null, isPersistent);
}

ToolBox createToolBox(PlatformType platformType);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import dev.langchain4j.store.memory.chat.ChatMemoryStore;

public interface MessageStoreProvider {
MessageRepository getMessageRepository();
MessageRepository getMessageRepository(boolean isPersistent);

ChatMemoryStore getChatMemoryStore();
ChatMemoryStore getChatMemoryStore(boolean isPersistent);
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ private AIAssistant buildAIAssistant(
.create(
getPlatformType(platformAuthorizedDTO.getPlatformName()),
getAIAssistantConfig(platformAuthorizedDTO, configs),
threadId);
threadId,
threadId != null);
}

private Boolean testAuthorization(PlatformAuthorizedDTO platformAuthorizedDTO) {
Expand Down

0 comments on commit 0048ed1

Please sign in to comment.