Skip to content

Commit

Permalink
BIGTOP-4313: Adjust the code of AI module (apache#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq authored Jan 3, 2025
1 parent 16fa4fd commit 5529204
Show file tree
Hide file tree
Showing 18 changed files with 420 additions and 497 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,74 +18,87 @@
*/
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.provider.LocSystemPromptProvider;
import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.assistant.config.GeneralAssistantConfig;
import org.apache.bigtop.manager.ai.assistant.provider.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.core.AbstractAIAssistantFactory;
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.exception.AssistantConfigNotSetException;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.dashscope.DashScopeAssistant;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;
import org.apache.bigtop.manager.ai.qianfan.QianFanAssistant;

import org.springframework.stereotype.Component;

import dev.langchain4j.service.tool.ToolProvider;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

import jakarta.annotation.Resource;
import java.util.ArrayList;
import java.util.List;

@Component
public class GeneralAssistantFactory extends AbstractAIAssistantFactory {

private final SystemPromptProvider systemPromptProvider;
private final ChatMemoryStoreProvider chatMemoryStoreProvider;
@Resource
private SystemPromptProvider systemPromptProvider;

@Resource
private ChatMemoryStoreProvider chatMemoryStoreProvider;

public GeneralAssistantFactory(ChatMemoryStoreProvider chatMemoryStoreProvider) {
this(new LocSystemPromptProvider(), chatMemoryStoreProvider);
private void configureSystemPrompt(AIAssistant.Builder builder, SystemPrompt systemPrompt, String locale) {
List<String> systemPrompts = new ArrayList<>();
if (systemPrompt != null) {
systemPrompts.add(systemPromptProvider.getSystemMessage(systemPrompt));
}
if (locale != null) {
systemPrompts.add(systemPromptProvider.getLanguagePrompt(locale));
}
builder.withSystemPrompt(systemPromptProvider.getSystemMessages(systemPrompts));
}

public GeneralAssistantFactory(
SystemPromptProvider systemPromptProvider, ChatMemoryStoreProvider chatMemoryStoreProvider) {
this.systemPromptProvider = systemPromptProvider;
this.chatMemoryStoreProvider = chatMemoryStoreProvider;
private AIAssistant.Builder initializeBuilder(PlatformType platformType) {
return switch (platformType) {
case OPENAI -> OpenAIAssistant.builder();
case DASH_SCOPE -> DashScopeAssistant.builder();
case QIANFAN -> QianFanAssistant.builder();
};
}

@Override
public AIAssistant createWithPrompt(
PlatformType platformType,
AIAssistantConfigProvider assistantConfig,
Object id,
ToolProvider toolProvider,
SystemPrompt systemPrompt) {
AIAssistant.Builder builder =
switch (platformType) {
case OPENAI -> OpenAIAssistant.builder();
case DASH_SCOPE -> DashScopeAssistant.builder();
case QIANFAN -> QianFanAssistant.builder();
};
builder = builder.id(id)
.memoryStore(
(id == null)
? new InMemoryChatMemoryStore()
: chatMemoryStoreProvider.createPersistentChatMemoryStore())
.withConfigProvider(assistantConfig)
.withToolProvider(toolProvider);

List<String> systemPrompts = new java.util.ArrayList<>();
systemPrompts.add(systemPromptProvider.getSystemMessage(systemPrompt));
String locale = assistantConfig.getLanguage();
if (locale != null) {
systemPrompts.add(systemPromptProvider.getLanguagePrompt(locale));
AIAssistantConfig config, ToolProvider toolProvider, SystemPrompt systemPrompt) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
Object id = generalAssistantConfig.getId();
if (id == null) {
throw new AssistantConfigNotSetException("ID");
}

builder.withSystemPrompt(systemPromptProvider.getSystemMessages(systemPrompts));
AIAssistant.Builder builder = initializeBuilder(platformType);
builder.id(id)
.memoryStore(chatMemoryStoreProvider.createPersistentChatMemoryStore())
.withConfig(generalAssistantConfig)
.withToolProvider(toolProvider);

configureSystemPrompt(builder, systemPrompt, generalAssistantConfig.getLanguage());

return builder.build();
}

@Override
public AIAssistant createAiService(
PlatformType platformType, AIAssistantConfigProvider assistantConfig, Long id, ToolProvider toolProvider) {
return createWithPrompt(platformType, assistantConfig, id, toolProvider, SystemPrompt.DEFAULT_PROMPT);
public AIAssistant createForTest(AIAssistantConfig config, ToolProvider toolProvider) {
GeneralAssistantConfig generalAssistantConfig = (GeneralAssistantConfig) config;
PlatformType platformType = generalAssistantConfig.getPlatformType();
AIAssistant.Builder builder = initializeBuilder(platformType);

builder.id(null)
.memoryStore(chatMemoryStoreProvider.createInMemoryChatMemoryStore())
.withConfig(generalAssistantConfig)
.withToolProvider(toolProvider);

return builder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,34 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.provider;
package org.apache.bigtop.manager.ai.assistant.config;

import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;

import lombok.Getter;

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class AIAssistantConfig implements AIAssistantConfigProvider {
@Getter
public class GeneralAssistantConfig implements AIAssistantConfig {

/**
* Model name for platform that we want to use
*/
private final Long id;
private final String model;

/**
* Credentials for different platforms
*/
private final Map<String, String> credentials;

private final String language;
/**
* Platform extra configs are put here
*/
private final PlatformType platformType;
private final Map<String, String> credentials;
private final Map<String, String> configs;

private AIAssistantConfig(
String model, Map<String, String> credentials, String language, Map<String, String> configMap) {
this.model = model;
this.credentials = credentials;
this.language = language;
this.configs = configMap;
private GeneralAssistantConfig(Builder builder) {
this.model = Objects.requireNonNull(builder.model);
this.credentials = Objects.requireNonNull(builder.credentials);
this.platformType = Objects.requireNonNull(builder.platformType);
this.language = builder.language;
this.id = builder.id;
this.configs = builder.configs;
}

public static Builder builder() {
Expand All @@ -68,26 +65,29 @@ public Map<String, String> getConfigs() {
return configs;
}

@Override
public String getLanguage() {
return language;
}

public static class Builder {
private Long id;
private String model;
private String language;

private PlatformType platformType;
private final Map<String, String> credentials = new HashMap<>();

private final Map<String, String> configs = new HashMap<>();

public Builder() {}

public Builder setModel(String model) {
this.model = model;
return this;
}

public Builder setPlatformType(PlatformType platformType) {
this.platformType = platformType;
return this;
}

public Builder setId(Long id) {
this.id = id;
return this;
}

public Builder addCredential(String key, String value) {
credentials.put(key, value);
return this;
Expand Down Expand Up @@ -115,8 +115,8 @@ public Builder addConfigs(Map<String, String> configMap) {
return this;
}

public AIAssistantConfig build() {
return new AIAssistantConfig(model, credentials, language, configs);
public GeneralAssistantConfig build() {
return new GeneralAssistantConfig(this);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,32 @@
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.ai.assistant.store;
package org.apache.bigtop.manager.ai.assistant.provider;

import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
import org.apache.bigtop.manager.dao.repository.ChatThreadDao;

import org.springframework.stereotype.Component;

import dev.langchain4j.store.memory.chat.ChatMemoryStore;
import dev.langchain4j.store.memory.chat.InMemoryChatMemoryStore;

import jakarta.annotation.Resource;

@Component
public class ChatMemoryStoreProvider {
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;
@Resource
private ChatThreadDao chatThreadDao;

public ChatMemoryStoreProvider(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
this.chatMessageDao = chatMessageDao;
}

public ChatMemoryStoreProvider() {
this(null, null);
}
@Resource
private ChatMessageDao chatMessageDao;

public ChatMemoryStore createPersistentChatMemoryStore() {
if (chatThreadDao == null || chatMessageDao == null) {
return new InMemoryChatMemoryStore();
}
return new PersistentChatMemoryStore(chatThreadDao, chatMessageDao);
}

public ChatMemoryStore createInMemoryChatMemoryStore() {
return new InMemoryChatMemoryStore();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.bigtop.manager.ai.core.enums.SystemPrompt;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;

import org.springframework.stereotype.Component;
import org.springframework.util.ResourceUtils;

import lombok.extern.slf4j.Slf4j;
Expand All @@ -33,6 +34,7 @@
import java.util.Objects;

@Slf4j
@Component
public class LocSystemPromptProvider implements SystemPromptProvider {

@Override
Expand Down Expand Up @@ -67,21 +69,14 @@ private String loadTextFromFile(String fileName) {
private String loadPromptFromFile(String fileName) {
final String filePath = fileName + ".st";
String text = loadTextFromFile(filePath);
if (text == null) {
return "You are a helpful assistant.";
} else {
return text;
}
return Objects.requireNonNullElse(text, "You are a helpful assistant.");
}

@Override
public String getLanguagePrompt(String locale) {
final String filePath = SystemPrompt.LANGUAGE_PROMPT.getValue() + '-' + locale + ".st";
String text = loadTextFromFile(filePath);
if (text == null) {
return "Answer in " + locale;
} else {
return text;
}
return Objects.requireNonNullElseGet(text, () -> "Answer in " + locale);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
public class PersistentChatMemoryStore implements ChatMemoryStore {

private final Map<Object, List<ChatMessage>> messagesByMemoryId = new ConcurrentHashMap<>();
protected final ChatThreadDao chatThreadDao;
protected final ChatMessageDao chatMessageDao;
private final ChatThreadDao chatThreadDao;
private final ChatMessageDao chatMessageDao;

public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao chatMessageDao) {
this.chatThreadDao = chatThreadDao;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@
*/
package org.apache.bigtop.manager.ai.assistant;

import org.apache.bigtop.manager.ai.assistant.store.ChatMemoryStoreProvider;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.config.AIAssistantConfig;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;
import org.apache.bigtop.manager.ai.core.provider.SystemPromptProvider;
import org.apache.bigtop.manager.ai.openai.OpenAIAssistant;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.MockedStatic;
import org.mockito.MockitoAnnotations;
Expand All @@ -42,42 +38,35 @@
class GeneralAssistantFactoryTest {

@Mock
private SystemPromptProvider systemPromptProvider;
private AIAssistantConfig assistantConfigProvider;

@Mock
private AIAssistantConfigProvider assistantConfigProvider;

@InjectMocks
private GeneralAssistantFactory generalAssistantFactory;

@BeforeEach
void setUp() {
MockitoAnnotations.openMocks(this);
generalAssistantFactory = new GeneralAssistantFactory(systemPromptProvider, new ChatMemoryStoreProvider());
Map<String, String> credentials = Map.of("apiKey", "123456");
when(assistantConfigProvider.getModel()).thenReturn("model");
when(assistantConfigProvider.getCredentials()).thenReturn(credentials);
when(assistantConfigProvider.getConfigs()).thenReturn(null);
when(assistantConfigProvider.getLanguage()).thenReturn("en");
}

@Test
void testCreateAIAssistant() {
AIAssistant.Builder mockBuilder = mock(OpenAIAssistant.Builder.class);
when(mockBuilder.id(any())).thenReturn(mockBuilder);
when(mockBuilder.memoryStore(any())).thenReturn(mockBuilder);
when(mockBuilder.withConfigProvider(any())).thenReturn(mockBuilder);
when(mockBuilder.withConfig(any())).thenReturn(mockBuilder);
when(mockBuilder.withToolProvider(any())).thenReturn(mockBuilder);
when(mockBuilder.withSystemPrompt(any())).thenReturn(mockBuilder);
when(mockBuilder.build()).thenReturn(mock(AIAssistant.class));

try (MockedStatic<OpenAIAssistant> openAIAssistantMockedStatic = mockStatic(OpenAIAssistant.class)) {
openAIAssistantMockedStatic.when(OpenAIAssistant::builder).thenReturn(mockBuilder);

PlatformType platformType = PlatformType.OPENAI;
generalAssistantFactory.create(platformType, assistantConfigProvider);
generalAssistantFactory = new GeneralAssistantFactory(new ChatMemoryStoreProvider());
generalAssistantFactory.create(platformType, assistantConfigProvider);
generalAssistantFactory.createAIService(assistantConfigProvider, null);
generalAssistantFactory.createForTest(assistantConfigProvider, null);
}
}
}
Loading

0 comments on commit 5529204

Please sign in to comment.