Skip to content

Commit

Permalink
BIGTOP-4213: Add LLM platform DashScope (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq authored Sep 10, 2024
1 parent 8dd681d commit 83238f8
Show file tree
Hide file tree
Showing 26 changed files with 868 additions and 170 deletions.
4 changes: 4 additions & 0 deletions bigtop-manager-ai/bigtop-manager-ai-assistant/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-ai-core</artifactId>
</dependency>
<dependency>
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-ai-dashscope</artifactId>
</dependency>
<dependency>
<groupId>org.apache.bigtop</groupId>
<artifactId>bigtop-manager-dao</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,44 +19,44 @@
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider;
import org.apache.bigtop.manager.ai.assistant.provider.PersistentStoreProvider;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.exception.PlatformNotFoundException;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.factory.ToolBox;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.MessageStoreProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.dashscope.DashScopeAssistant;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;

import org.apache.commons.lang3.NotImplementedException;

import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

import java.util.Objects;

public class GeneralAssistantFactory extends AbstractAIAssistantFactory {

private final SystemPromptProvider systemPromptProvider;
private final ChatMemoryStore chatMemoryStore;
private final MessageStoreProvider messageStoreProvider;

public GeneralAssistantFactory() {
this(new LocSystemPromptProvider(), new InMemoryChatMemoryStore());
this(new LocSystemPromptProvider(), new PersistentStoreProvider());
}

public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider) {
this(systemPromptProvider, new InMemoryChatMemoryStore());
this(systemPromptProvider, new PersistentStoreProvider());
}

public GeneralAssistantFactory(ChatMemoryStore chatMemoryStore) {
this(new LocSystemPromptProvider(), chatMemoryStore);
public GeneralAssistantFactory(MessageStoreProvider messageStoreProvider) {
this(new LocSystemPromptProvider(), messageStoreProvider);
}

public GeneralAssistantFactory(SystemPromptProvider systemPromptProvider, ChatMemoryStore chatMemoryStore) {
public GeneralAssistantFactory(
SystemPromptProvider systemPromptProvider, MessageStoreProvider messageStoreProvider) {
this.systemPromptProvider = systemPromptProvider;
this.chatMemoryStore = chatMemoryStore;
this.messageStoreProvider = messageStoreProvider;
}

@Override
Expand All @@ -69,14 +69,19 @@ public AIAssistant createWithPrompt(
if (Objects.requireNonNull(platformType) == PlatformType.OPENAI) {
aiAssistant = OpenAIAssistant.builder()
.id(id)
.memoryStore(chatMemoryStore)
.memoryStore(messageStoreProvider.getChatMemoryStore())
.withConfigProvider(assistantConfig)
.build();
} else if (Objects.requireNonNull(platformType) == PlatformType.DASH_SCOPE) {
aiAssistant = DashScopeAssistant.builder()
.id(id)
.withConfigProvider(assistantConfig)
.messageRepository(messageStoreProvider.getMessageRepository())
.build();
} else {
throw new PlatformNotFoundException(platformType.getValue());
}

SystemMessage systemPrompt = systemPromptProvider.getSystemPrompt(systemPrompts);
String systemPrompt = systemPromptProvider.getSystemMessage(systemPrompts);
aiAssistant.setSystemPrompt(systemPrompt);
String locale = assistantConfig.getLanguage();
if (locale != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ public Builder addConfig(String key, String value) {
}

public Builder addConfigs(Map<String, String> configMap) {
configs.putAll(configMap);
if (configMap != null) {
configs.putAll(configMap);
}
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import org.springframework.util.ResourceUtils;

import dev.langchain4j.data.message.SystemMessage;
import lombok.extern.slf4j.Slf4j;

import java.io.File;
Expand All @@ -36,7 +35,7 @@
public class LocSystemPromptProvider implements SystemPromptProvider {

@Override
public SystemMessage getSystemPrompt(SystemPrompt systemPrompt) {
public String getSystemMessage(SystemPrompt systemPrompt) {
if (systemPrompt == SystemPrompt.DEFAULT_PROMPT) {
systemPrompt = SystemPrompt.BIGDATA_PROFESSOR;
}
Expand All @@ -45,8 +44,8 @@ public SystemMessage getSystemPrompt(SystemPrompt systemPrompt) {
}

@Override
public SystemMessage getSystemPrompt() {
return getSystemPrompt(SystemPrompt.DEFAULT_PROMPT);
public String getSystemMessage() {
return getSystemMessage(SystemPrompt.DEFAULT_PROMPT);
}

private String loadTextFromFile(String fileName) {
Expand All @@ -64,23 +63,23 @@ private String loadTextFromFile(String fileName) {
}
}

private SystemMessage loadPromptFromFile(String fileName) {
private String loadPromptFromFile(String fileName) {
final String filePath = fileName + ".st";
String text = loadTextFromFile(filePath);
if (text == null) {
return SystemMessage.from("You are a helpful assistant.");
return "You are a helpful assistant.";
} else {
return SystemMessage.from(text);
return text;
}
}

public SystemMessage getLanguagePrompt(String locale) {
public String getLanguagePrompt(String locale) {
final String filePath = SystemPrompt.LANGUAGE_PROMPT.getValue() + '-' + locale + ".st";
String text = loadTextFromFile(filePath);
if (text == null) {
return SystemMessage.from("Answer in " + locale);
return "Answer in " + locale;
} else {
return SystemMessage.from(text);
return text;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.provider;

import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore;
import org.apache.bigtop.manager.ai.assistant.store.PersistentMessageRepository;
import org.apache.bigtop.manager.ai.core.provider.MessageStoreProvider;
import org.apache.bigtop.manager.ai.core.repository.MessageRepository;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;
import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
public class PersistentStoreProvider implements MessageStoreProvider {
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

public PersistentStoreProvider(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

public PersistentStoreProvider() {
chatMessageDao = null;
chatThreadDao = null;
}

@Override
public MessageRepository getMessageRepository() {
return new PersistentMessageRepository(chatThreadDao, chatMessageDao);
}

@Override
public ChatMemoryStore getChatMemoryStore() {
if (chatThreadDao == null) {
return new InMemoryChatMemoryStore();
}
return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.store;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.repository.MessageRepository;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

public class PersistentMessageRepository implements MessageRepository {
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

private boolean noPersistent() {
return chatThreadDao == null || chatMessageDao == null;
}

public PersistentMessageRepository(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

private ChatMessagePO getChatMessagePO(String message, Long threadId, MessageSender sender) {
if (noPersistent()) {
return null;
}
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
ChatMessagePO chatMessagePO = new ChatMessagePO();
chatMessagePO.setUserId(chatThreadPO.getUserId());
chatMessagePO.setThreadId(threadId);
chatMessagePO.setSender(sender.getValue());
chatMessagePO.setMessage(message);
return chatMessagePO;
}

@Override
public void saveUserMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.USER);
chatMessageDao.save(chatMessagePO);
}

@Override
public void saveAiMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.AI);
chatMessageDao.save(chatMessagePO);
}

@Override
public void saveSystemMessage(String message, Long threadId) {
if (noPersistent()) {
return;
}
ChatMessagePO chatMessagePO = getChatMessagePO(message, threadId, MessageSender.SYSTEM);
chatMessageDao.save(chatMessagePO);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@

import org.junit.jupiter.api.Test;

import dev.langchain4j.data.message.SystemMessage;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

Expand All @@ -35,12 +33,12 @@ public class SystemPromptProviderTests {

@Test
public void loadSystemPromptByIdTest() {
SystemMessage systemPrompt1 = systemPromptProvider.getSystemPrompt(SystemPrompt.BIGDATA_PROFESSOR);
assertFalse(systemPrompt1.text().isEmpty());
String systemPrompt1 = systemPromptProvider.getSystemMessage(SystemPrompt.BIGDATA_PROFESSOR);
assertFalse(systemPrompt1.isEmpty());

SystemMessage systemPrompt2 = systemPromptProvider.getSystemPrompt();
assertFalse(systemPrompt2.text().isEmpty());
String systemPrompt2 = systemPromptProvider.getSystemMessage();
assertFalse(systemPrompt2.isEmpty());

assertEquals(systemPrompt1.text(), systemPrompt2.text());
assertEquals(systemPrompt1, systemPrompt2);
}
}
Loading

0 comments on commit 83238f8

Please sign in to comment.