Skip to content

Commit

Permalink
rename messageSender
Browse files Browse the repository at this point in the history
  • Loading branch information
lhpqaq committed Sep 14, 2024
1 parent 6087fef commit 8be1500
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package org.apache.bigtop.manager.ai.assistant.store;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.dao.po.ChatThreadPO;
import org.apache.bigtop.manager.dao.repository.ChatMessageDao;
Expand All @@ -27,7 +27,6 @@
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.ChatMessageType;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;

Expand All @@ -47,12 +46,10 @@ public PersistentChatMemoryStore(ChatThreadDao chatThreadDao, ChatMessageDao cha

private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
String sender = chatMessagePO.getSender().toLowerCase();
if (sender.equals(MessageSender.AI.getValue())) {
if (sender.equals(MessageType.AI.getValue())) {
return new AiMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.USER.getValue())) {
} else if (sender.equals(MessageType.USER.getValue())) {
return new UserMessage(chatMessagePO.getMessage());
} else if (sender.equals(MessageSender.SYSTEM.getValue())) {
return new SystemMessage(chatMessagePO.getMessage());
} else {
return null;
}
Expand All @@ -61,19 +58,15 @@ private ChatMessage convertToChatMessage(ChatMessagePO chatMessagePO) {
private ChatMessagePO convertToChatMessagePO(ChatMessage chatMessage, Long chatThreadId) {
ChatMessagePO chatMessagePO = new ChatMessagePO();
if (chatMessage.type().equals(ChatMessageType.AI)) {
chatMessagePO.setSender(MessageSender.AI.getValue());
chatMessagePO.setSender(MessageType.AI.getValue());
AiMessage aiMessage = (AiMessage) chatMessage;
chatMessagePO.setMessage(aiMessage.text());
} else if (chatMessage.type().equals(ChatMessageType.USER)) {
chatMessagePO.setSender(MessageSender.USER.getValue());
chatMessagePO.setSender(MessageType.USER.getValue());
UserMessage userMessage = (UserMessage) chatMessage;
chatMessagePO.setMessage(userMessage.singleText());
} else if (chatMessage.type().equals(ChatMessageType.SYSTEM)) {
chatMessagePO.setSender(MessageSender.SYSTEM.getValue());
SystemMessage systemMessage = (SystemMessage) chatMessage;
chatMessagePO.setMessage(systemMessage.text());
} else {
chatMessagePO.setSender(chatMessage.type().toString());
return null;
}
ChatThreadPO chatThreadPO = chatThreadDao.findByThreadId(chatThreadId);
chatMessagePO.setUserId(chatThreadPO.getUserId());
Expand All @@ -94,6 +87,9 @@ public List<ChatMessage> getMessages(Object threadId) {
@Override
public void updateMessages(Object threadId, List<ChatMessage> messages) {
ChatMessagePO chatMessagePO = convertToChatMessagePO(messages.get(messages.size() - 1), (Long) threadId);
if (chatMessagePO == null) {
return;
}
chatMessageDao.save(chatMessagePO);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,28 @@
import java.util.stream.Collectors;

@Getter
public enum MessageSender {
public enum MessageType {
USER("user"),
AI("ai"),
SYSTEM("system");

private final String value;

MessageSender(String value) {
MessageType(String value) {
this.value = value;
}

public static List<String> getSenders() {
return Arrays.stream(values()).map(item -> item.value).collect(Collectors.toList());
}

public static MessageSender getMessageSender(String value) {
public static MessageType getMessageSender(String value) {
if (Objects.isNull(value) || value.isEmpty()) {
return null;
}
for (MessageSender messageSender : MessageSender.values()) {
if (messageSender.value.equals(value)) {
return messageSender;
for (MessageType messageType : MessageType.values()) {
if (messageType.value.equals(value)) {
return messageType;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
package org.apache.bigtop.manager.ai.dashscope;

import org.apache.bigtop.manager.ai.core.AbstractAIAssistant;
import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;

Expand Down Expand Up @@ -88,13 +88,13 @@ private String getValueFromAssistantStreamMessage(AssistantStreamMessage assista
return streamMessage.toString();
}

private void saveMessage(String message, MessageSender sender) {
private void saveMessage(String message, MessageType sender) {
ChatMessage chatMessage;
if (sender.equals(MessageSender.AI)) {
if (sender.equals(MessageType.AI)) {
chatMessage = new AiMessage(message);
} else if (sender.equals(MessageSender.USER)) {
} else if (sender.equals(MessageType.USER)) {
chatMessage = new UserMessage(message);
} else if (sender.equals(MessageSender.SYSTEM)) {
} else if (sender.equals(MessageType.SYSTEM)) {
chatMessage = new SystemMessage(message);
} else {
return;
Expand Down Expand Up @@ -131,7 +131,7 @@ public void setSystemPrompt(String systemPrompt) {
} catch (NoApiKeyException | InputRequiredException | InvalidateParameter e) {
throw new RuntimeException(e);
}
saveMessage(systemPrompt, MessageSender.SYSTEM);
saveMessage(systemPrompt, MessageType.SYSTEM);
}

public static Builder builder() {
Expand All @@ -140,7 +140,7 @@ public static Builder builder() {

@Override
public Flux<String> streamAsk(String userMessage) {
saveMessage(userMessage, MessageSender.USER);
saveMessage(userMessage, MessageType.USER);
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -174,13 +174,13 @@ public Flux<String> streamAsk(String userMessage) {
return message;
})
.doOnComplete(() -> {
saveMessage(finalMessage.toString(), MessageSender.AI);
saveMessage(finalMessage.toString(), MessageType.AI);
});
}

@Override
public String ask(String userMessage) {
saveMessage(userMessage, MessageSender.USER);
saveMessage(userMessage, MessageType.USER);
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -244,7 +244,7 @@ public String ask(String userMessage) {
ContentText contentText = (ContentText) content;
finalMessage.append(contentText.getText().getValue());
}
saveMessage(finalMessage.toString(), MessageSender.AI);
saveMessage(finalMessage.toString(), MessageType.AI);
return finalMessage.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package org.apache.bigtop.manager.server.model.converter;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.dao.po.ChatMessagePO;
import org.apache.bigtop.manager.server.config.MapStructSharedConfig;
import org.apache.bigtop.manager.server.model.vo.ChatMessageVO;
Expand All @@ -32,12 +32,12 @@ public interface ChatMessageConverter {

ChatMessageVO fromPO2VO(ChatMessagePO chatMessagePO);

default MessageSender mapStringToMessageSender(String sender) {
default MessageType mapStringToMessageSender(String sender) {
if (sender == null) {
return null;
}
try {
return MessageSender.valueOf(sender.toUpperCase());
return MessageType.valueOf(sender.toUpperCase());
} catch (IllegalArgumentException e) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@
*/
package org.apache.bigtop.manager.server.model.vo;

import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;

import lombok.Data;

@Data
public class ChatMessageVO {
private MessageSender sender;
private MessageType sender;

private String message;

private String createTime;

public ChatMessageVO(MessageSender sender, String messageText, String createTime) {
public ChatMessageVO(MessageType sender, String messageText, String createTime) {
this.sender = sender;
this.message = messageText;
this.createTime = createTime;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.apache.bigtop.manager.ai.assistant.GeneralAssistantFactory;
import org.apache.bigtop.manager.ai.assistant.provider.AIAssistantConfig;
import org.apache.bigtop.manager.ai.assistant.store.PersistentChatMemoryStore;
import org.apache.bigtop.manager.ai.core.enums.MessageSender;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.factory.AIAssistantFactory;
Expand Down Expand Up @@ -325,11 +325,11 @@ public List<ChatMessageVO> history(Long platformId, Long threadId) {
List<ChatMessagePO> chatMessagePOs = chatMessageDao.findAllByThreadId(threadId);
for (ChatMessagePO chatMessagePO : chatMessagePOs) {
ChatMessageVO chatMessageVO = ChatMessageConverter.INSTANCE.fromPO2VO(chatMessagePO);
MessageSender sender = chatMessageVO.getSender();
MessageType sender = chatMessageVO.getSender();
if (sender == null) {
continue;
}
if (sender.equals(MessageSender.USER) || sender.equals(MessageSender.AI)) {
if (sender.equals(MessageType.USER) || sender.equals(MessageType.AI)) {
chatMessages.add(chatMessageVO);
}
}
Expand Down

0 comments on commit 8be1500

Please sign in to comment.