Skip to content

Commit

Permalink
Adjust the code order
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Dec 30, 2024
1 parent f419889 commit 4fcc13d
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 218 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ private void configureSystemPrompt(AIAssistant.Builder builder, SystemPrompt sys
if (systemPrompt != null) {
systemPrompts.add(locSystemPromptProvider.getSystemMessage(systemPrompt));
}

if (locale != null) {
systemPrompts.add(locSystemPromptProvider.getLanguagePrompt(locale));
}
Expand All @@ -81,7 +80,7 @@ public AIAssistant createWithPrompt(
}

AIAssistant.Builder builder = initializeBuilder(platformType);
builder = builder.id(id)
builder.id(id)
.memoryStore(chatMemoryStoreProvider.createPersistentChatMemoryStore())
.withConfigProvider(generalAssistantConfig)
.withToolProvider(toolProvider);
Expand All @@ -97,7 +96,7 @@ public AIAssistant createForTest(AIAssistantConfig config, ToolProvider toolProv
PlatformType platformType = generalAssistantConfig.getPlatformType();
AIAssistant.Builder builder = initializeBuilder(platformType);

builder = builder.id(null)
builder.id(null)
.memoryStore(chatMemoryStoreProvider.createInMemoryChatMemoryStore())
.withConfigProvider(generalAssistantConfig)
.withToolProvider(toolProvider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ public class GeneralAssistantConfig implements AIAssistantConfig {
private final String language;
private final PlatformType platformType;
private final Map<String, String> credentials;
/**
* Platform extra configs are put here
*/
private final Map<String, String> configs;

private GeneralAssistantConfig(Builder builder) {
Expand Down Expand Up @@ -76,8 +73,6 @@ public static class Builder {
private final Map<String, String> credentials = new HashMap<>();
private final Map<String, String> configs = new HashMap<>();

public Builder() {}

public Builder setModel(String model) {
this.model = model;
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,58 +80,9 @@ public class ChatbotServiceImpl implements ChatbotService {
@Resource
private AIAssistantFactory aiAssistantFactory;

private AuthPlatformPO validateAndGetActiveAuthPlatform() {
AuthPlatformPO authPlatform = null;
List<AuthPlatformPO> authPlatformPOS = authPlatformDao.findAll();
for (AuthPlatformPO authPlatformPO : authPlatformPOS) {
if (AuthPlatformStatus.isActive(authPlatformPO.getStatus())) {
authPlatform = authPlatformPO;
}
}
if (authPlatform == null || authPlatform.getIsDeleted()) {
throw new ApiException(ApiExceptionEnum.NO_PLATFORM_IN_USE);
}
return authPlatform;
}

private ChatThreadPO validateAndGetChatThread(Long threadId) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
if (chatThreadPO == null || chatThreadPO.getIsDeleted()) {
throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND);
}
Long userId = SessionUserHolder.getUserId();
if (!chatThreadPO.getUserId().equals(userId)) {
throw new ApiException(ApiExceptionEnum.PERMISSION_DENIED);
}
return chatThreadPO;
}

private GeneralAssistantConfig getAIAssistantConfig(
String platformName, String model, Map<String, String> credentials, Long id) {
return GeneralAssistantConfig.builder()
.setPlatformType(getPlatformType(platformName))
.setModel(model)
.setId(id)
.setLanguage(LocaleContextHolder.getLocale().toString())
.addCredentials(credentials)
.build();
}

private PlatformType getPlatformType(String platformName) {
return PlatformType.getPlatformType(platformName.toLowerCase());
}

private AIAssistant buildAIAssistant(
String platformName, String model, Map<String, String> credentials, Long threadId, ChatbotCommand command) {
return aiAssistantFactory.createAIService(
getAIAssistantConfig(platformName, model, credentials, threadId),
aiServiceToolsProvider.getToolsProvide(command));
}

@Override
public ChatThreadVO createChatThread(ChatThreadDTO chatThreadDTO) {
AuthPlatformPO authPlatformPO = validateAndGetActiveAuthPlatform();

PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId());

chatThreadDTO.setPlatformId(platformPO.getId());
Expand All @@ -140,6 +91,7 @@ public ChatThreadVO createChatThread(ChatThreadDTO chatThreadDTO) {
ChatThreadPO chatThreadPO = ChatThreadConverter.INSTANCE.fromDTO2PO(chatThreadDTO);
chatThreadPO.setUserId(SessionUserHolder.getUserId());
chatThreadDao.save(chatThreadPO);

return ChatThreadConverter.INSTANCE.fromPO2VO(chatThreadPO, authPlatformPO, platformPO);
}

Expand Down Expand Up @@ -175,48 +127,6 @@ public List<ChatThreadVO> getAllChatThreads() {
return chatThreads;
}

private AIAssistant prepareTalk(Long threadId, ChatbotCommand command) {
ChatThreadPO chatThreadPO = validateAndGetChatThread(threadId);
AuthPlatformPO authPlatformPO = validateAndGetActiveAuthPlatform();

if (!authPlatformPO.getId().equals(chatThreadPO.getAuthId())) {
throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_IN_USE);
}

AuthPlatformDTO authPlatformDTO = AuthPlatformConverter.INSTANCE.fromPO2DTO(authPlatformPO);

PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId());
return buildAIAssistant(
platformPO.getName(),
authPlatformDTO.getModel(),
authPlatformDTO.getAuthCredentials(),
threadId,
command);
}

private void sendTalkVO(SseEmitter emitter, String content, String finishReason) {
try {
TalkVO talkVO = new TalkVO();
talkVO.setContent(content);
talkVO.setFinishReason(finishReason);
emitter.send(talkVO);
} catch (Exception e) {
log.error("Error sending data to SseEmitter", e);
emitter.completeWithError(e);
}
}

private void handleError(SseEmitter emitter, Throwable throwable) {
log.error("Error during SSE streaming: {}", throwable.getMessage(), throwable);
sendTalkVO(emitter, null, "Error: " + throwable.getMessage());
emitter.completeWithError(throwable);
}

private void completeEmitter(SseEmitter emitter) {
sendTalkVO(emitter, null, "completed");
emitter.complete();
}

@Override
public SseEmitter talk(Long threadId, ChatbotCommand command, String message) {
AIAssistant aiAssistant = prepareTalk(threadId, command);
Expand Down Expand Up @@ -280,4 +190,94 @@ public ChatThreadVO getChatThread(Long threadId) {
return ChatThreadConverter.INSTANCE.fromPO2VO(
chatThreadPO, authPlatformPO, platformDao.findById(authPlatformPO.getPlatformId()));
}

private AuthPlatformPO validateAndGetActiveAuthPlatform() {
AuthPlatformPO authPlatform = null;
List<AuthPlatformPO> authPlatformPOS = authPlatformDao.findAll();
for (AuthPlatformPO authPlatformPO : authPlatformPOS) {
if (AuthPlatformStatus.isActive(authPlatformPO.getStatus())) {
authPlatform = authPlatformPO;
}
}
if (authPlatform == null || authPlatform.getIsDeleted()) {
throw new ApiException(ApiExceptionEnum.NO_PLATFORM_IN_USE);
}
return authPlatform;
}

private ChatThreadPO validateAndGetChatThread(Long threadId) {
ChatThreadPO chatThreadPO = chatThreadDao.findById(threadId);
if (chatThreadPO == null || chatThreadPO.getIsDeleted()) {
throw new ApiException(ApiExceptionEnum.CHAT_THREAD_NOT_FOUND);
}
Long userId = SessionUserHolder.getUserId();
if (!chatThreadPO.getUserId().equals(userId)) {
throw new ApiException(ApiExceptionEnum.PERMISSION_DENIED);
}
return chatThreadPO;
}

private GeneralAssistantConfig getAIAssistantConfig(
String platformName, String model, Map<String, String> credentials, Long id) {
return GeneralAssistantConfig.builder()
.setPlatformType(getPlatformType(platformName))
.setModel(model)
.setId(id)
.setLanguage(LocaleContextHolder.getLocale().toString())
.addCredentials(credentials)
.build();
}

private PlatformType getPlatformType(String platformName) {
return PlatformType.getPlatformType(platformName.toLowerCase());
}

private AIAssistant buildAIAssistant(
String platformName, String model, Map<String, String> credentials, Long threadId, ChatbotCommand command) {
return aiAssistantFactory.createAIService(
getAIAssistantConfig(platformName, model, credentials, threadId),
aiServiceToolsProvider.getToolsProvide(command));
}

private AIAssistant prepareTalk(Long threadId, ChatbotCommand command) {
ChatThreadPO chatThreadPO = validateAndGetChatThread(threadId);
AuthPlatformPO authPlatformPO = validateAndGetActiveAuthPlatform();

if (!authPlatformPO.getId().equals(chatThreadPO.getAuthId())) {
throw new ApiException(ApiExceptionEnum.PLATFORM_NOT_IN_USE);
}

AuthPlatformDTO authPlatformDTO = AuthPlatformConverter.INSTANCE.fromPO2DTO(authPlatformPO);
PlatformPO platformPO = platformDao.findById(authPlatformPO.getPlatformId());

return buildAIAssistant(
platformPO.getName(),
authPlatformDTO.getModel(),
authPlatformDTO.getAuthCredentials(),
threadId,
command);
}

private void sendTalkVO(SseEmitter emitter, String content, String finishReason) {
try {
TalkVO talkVO = new TalkVO();
talkVO.setContent(content);
talkVO.setFinishReason(finishReason);
emitter.send(talkVO);
} catch (Exception e) {
log.error("Error sending data to SseEmitter", e);
emitter.completeWithError(e);
}
}

private void handleError(SseEmitter emitter, Throwable throwable) {
log.error("Error during SSE streaming: {}", throwable.getMessage(), throwable);
sendTalkVO(emitter, null, "Error: " + throwable.getMessage());
emitter.completeWithError(throwable);
}

private void completeEmitter(SseEmitter emitter) {
sendTalkVO(emitter, null, "completed");
emitter.complete();
}
}
Loading

0 comments on commit 4fcc13d

Please sign in to comment.