Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BIGTOP-4233: Add thread name API to Chatbot #76

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ public AIAssistant createWithPrompt(

String systemPrompt = systemPromptProvider.getSystemMessage(systemPrompts);
aiAssistant.setSystemPrompt(systemPrompt);
if (id != null) {
aiAssistant.setThreadNameGenerator(systemPromptProvider.getThreadNameGenerator());
}
String locale = assistantConfig.getLanguage();
if (locale != null) {
aiAssistant.setSystemPrompt(systemPromptProvider.getLanguagePrompt(locale));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,4 +82,8 @@ public String getLanguagePrompt(String locale) {
return text;
}
}

public String getThreadNameGenerator() {
return getSystemMessage(SystemPrompt.THREAD_NAME_GENERATOR);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Please generate a name for our conversation, no more than 10 words, and only answer the generated name without saying anything else.
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,65 @@
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;
import org.apache.bigtop.manager.ai.core.provider.AIAssistantConfigProvider;

import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.memory.ChatMemory;
import dev.langchain4j.store.memory.chat.ChatMemoryStore;

public abstract class AbstractAIAssistant implements AIAssistant {

protected static final Integer MEMORY_LEN = 10;
protected static final Integer THREAD_NAME_LEN = 100;
protected final ChatMemory chatMemory;
private String threadNameGenerator;

protected AbstractAIAssistant(ChatMemory chatMemory) {
this.chatMemory = chatMemory;
}

@Override
public String ask(String chatMessage) {
chatMemory.add(UserMessage.from(chatMessage));
String aiMessage = runAsk(chatMessage);
chatMemory.add(AiMessage.from(aiMessage));
return aiMessage;
}

@Override
public boolean test() {
return ask("1+1=") != null;
return runAsk("1+1=") != null;
}

@Override
public String getThreadName() {
if (threadNameGenerator == null) {
return null;
}
boolean hasUserMessage = false;
for (ChatMessage message : chatMemory.messages()) {
if (message instanceof UserMessage) {
hasUserMessage = true;
break;
}
}
if (!hasUserMessage) {
return null;
}
String threadName = runAsk(threadNameGenerator);
return threadName.length() > THREAD_NAME_LEN ? threadName.substring(0, THREAD_NAME_LEN) : threadName;
}

@Override
public Object getId() {
return chatMemory.id();
}

@Override
public void setThreadNameGenerator(String threadNameGenerator) {
this.threadNameGenerator = threadNameGenerator;
}

public abstract static class Builder implements AIAssistant.Builder {
protected Object id;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
public enum SystemPrompt {
DEFAULT_PROMPT("default"),
BIGDATA_PROFESSOR("big-data-professor"),
LANGUAGE_PROMPT("language-prompt");
LANGUAGE_PROMPT("language-prompt"),
THREAD_NAME_GENERATOR("thread-name-generator");
;

private final String value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ public interface AIAssistant {
*/
String ask(String userMessage);

/**
* Have a conversation that will not be saved.
* @param message
* @return
*/
String runAsk(String message);
/**
* This is used to get the AIAssistant's Platform
* @return
Expand All @@ -76,6 +82,18 @@ default Map<String, String> createThread() {
*/
boolean test();

/**
* Set prompt for generating thread name.
* @return
*/
void setThreadNameGenerator(String threadNameGenerator);

/**
* Get the name of the chat thread
* @return
*/
String getThreadName();

interface Builder {
Builder id(Object id);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,6 @@ public interface SystemPromptProvider {
String getSystemMessage();

String getLanguagePrompt(String locale);

String getThreadNameGenerator();
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
package org.apache.bigtop.manager.ai.dashscope;

import org.apache.bigtop.manager.ai.core.AbstractAIAssistant;
import org.apache.bigtop.manager.ai.core.enums.MessageType;
import org.apache.bigtop.manager.ai.core.enums.PlatformType;
import org.apache.bigtop.manager.ai.core.factory.AIAssistant;

Expand Down Expand Up @@ -50,7 +49,6 @@
import com.alibaba.dashscope.threads.runs.RunParam;
import com.alibaba.dashscope.threads.runs.Runs;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.ValidationUtils;
Expand Down Expand Up @@ -88,20 +86,6 @@ private String getValueFromAssistantStreamMessage(AssistantStreamMessage assista
return streamMessage.toString();
}

private void addMessage(String message, MessageType sender) {
ChatMessage chatMessage;
if (sender.equals(MessageType.AI)) {
chatMessage = new AiMessage(message);
} else if (sender.equals(MessageType.USER)) {
chatMessage = new UserMessage(message);
} else if (sender.equals(MessageType.SYSTEM)) {
chatMessage = new SystemMessage(message);
} else {
return;
}
chatMemory.add(chatMessage);
}

@Override
public PlatformType getPlatform() {
return PlatformType.DASH_SCOPE;
Expand Down Expand Up @@ -131,7 +115,7 @@ public void setSystemPrompt(String systemPrompt) {
} catch (NoApiKeyException | InputRequiredException | InvalidateParameter e) {
throw new RuntimeException(e);
}
addMessage(systemPrompt, MessageType.SYSTEM);
chatMemory.add(SystemMessage.systemMessage(systemPrompt));
}

public static Builder builder() {
Expand All @@ -140,7 +124,7 @@ public static Builder builder() {

@Override
public Flux<String> streamAsk(String userMessage) {
addMessage(userMessage, MessageType.USER);
chatMemory.add(UserMessage.from(userMessage));
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -174,13 +158,11 @@ public Flux<String> streamAsk(String userMessage) {
return message;
})
.doOnComplete(() -> {
addMessage(finalMessage.toString(), MessageType.AI);
chatMemory.add(AiMessage.from(finalMessage.toString()));
});
}

@Override
public String ask(String userMessage) {
addMessage(userMessage, MessageType.USER);
public String runAsk(String userMessage) {
TextMessageParam textMessageParam = TextMessageParam.builder()
.apiKey(dashScopeThreadParam.getApiKey())
.role(Role.USER.getValue())
Expand Down Expand Up @@ -244,7 +226,6 @@ public String ask(String userMessage) {
ContentText contentText = (ContentText) content;
finalMessage.append(contentText.getText().getValue());
}
addMessage(finalMessage.toString(), MessageType.AI);
return finalMessage.toString();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,9 @@ public void onComplete(Response<AiMessage> response) {
}

@Override
public String ask(String chatMessage) {
chatMemory.add(UserMessage.from(chatMessage));
public String runAsk(String chatMessage) {
Response<AiMessage> generate = chatLanguageModel.generate(chatMemory.messages());
String aiMessage = generate.content().text();
chatMemory.add(AiMessage.from(aiMessage));
return aiMessage;
return generate.content().text();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,9 @@ public void onComplete(Response<AiMessage> response) {
}

@Override
public String ask(String chatMessage) {
chatMemory.add(UserMessage.from(chatMessage));
public String runAsk(String chatMessage) {
Response<AiMessage> generate = chatLanguageModel.generate(chatMemory.messages());
String aiMessage = generate.content().text();
chatMemory.add(AiMessage.from(aiMessage));
return aiMessage;
return generate.content().text();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ public class ChatThreadPO extends BasePO implements Serializable {
@Column(name = "model", nullable = false, length = 255)
private String model;

@Column(name = "name", length = 255)
private String name;

@Column(name = "thread_info", columnDefinition = "json")
private Map<String, String> threadInfo;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,10 @@ public static <Entity> String updateList(
primaryKey = keywordsFormat(entry.getValue(), DBType.MYSQL);
continue;
}

Field field = ReflectionUtils.findField(entityClass, entry.getKey());
if (field == null || checkBaseField(field)) {
continue;
}
StringBuilder caseClause = new StringBuilder();
caseClause
.append(keywordsFormat(entry.getValue(), DBType.MYSQL))
Expand All @@ -301,10 +304,6 @@ public static <Entity> String updateList(
if (ps == null || ps.getReadMethod() == null) {
continue;
}
Field field = ReflectionUtils.findField(entityClass, entry.getKey());
if (field == null || checkBaseField(field)) {
continue;
}
Object value = ReflectionUtils.invokeMethod(ps.getReadMethod(), entity);
PropertyDescriptor pkPs =
BeanUtils.getPropertyDescriptor(entityClass, tableMetaData.getPkProperty());
Expand Down Expand Up @@ -335,8 +334,12 @@ public static <Entity> String updateList(
.append("' THEN NULL ");
}
}
if (caseClause.toString().endsWith("CASE ")) {
caseClause.append("WHEN TRUE THEN ");
} else {
caseClause.append("ELSE ");
}
caseClause
.append("ELSE ")
.append(keywordsFormat(entry.getValue(), DBType.MYSQL))
.append(" ");
caseClause.append("END");
Expand All @@ -360,9 +363,7 @@ public static <Entity> String updateList(
case POSTGRESQL: {
sqlBuilder
.append("UPDATE ")
.append("\"")
.append(tableMetaData.getTableName())
.append("\"")
.append(keywordsFormat(tableMetaData.getTableName(), DBType.POSTGRESQL))
.append(" SET ");
Map<String, StringBuilder> setClauses = new LinkedHashMap<>();
String primaryKey = keywordsFormat("id", DBType.POSTGRESQL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.bigtop.manager.server.model.converter.PlatformConverter;
import org.apache.bigtop.manager.server.model.dto.PlatformDTO;
import org.apache.bigtop.manager.server.model.req.ChatbotMessageReq;
import org.apache.bigtop.manager.server.model.req.ChatbotThreadReq;
import org.apache.bigtop.manager.server.model.req.PlatformReq;
import org.apache.bigtop.manager.server.model.vo.ChatMessageVO;
import org.apache.bigtop.manager.server.model.vo.ChatThreadVO;
Expand Down Expand Up @@ -112,6 +113,19 @@ public SseEmitter talk(
return chatbotService.talk(platformId, threadId, messageReq.getMessage());
}

@Operation(summary = "get name", description = "Get name of the thread")
@GetMapping("platforms/{platformId}/threads/{threadId}/name")
public ResponseEntity<ChatThreadVO> getThreadName(@PathVariable Long platformId, @PathVariable Long threadId) {
return ResponseEntity.success(chatbotService.getThreadName(platformId, threadId));
}

@Operation(summary = "get name", description = "Get name of the thread")
@PostMapping("platforms/{platformId}/threads/{threadId}/name")
public ResponseEntity<Boolean> setThreadName(
@PathVariable Long platformId, @PathVariable Long threadId, @RequestBody ChatbotThreadReq threadReq) {
return ResponseEntity.success(chatbotService.setThreadName(platformId, threadId, threadReq.getNewName()));
}

@Operation(summary = "history", description = "Get chat records")
@GetMapping("platforms/{platformId}/threads/{threadId}/history")
public ResponseEntity<List<ChatMessageVO>> history(@PathVariable Long platformId, @PathVariable Long threadId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ public enum ApiExceptionEnum {
CREDIT_INCORRECT(19003, LocaleKeys.CREDIT_INCORRECT),
MODEL_NOT_SUPPORTED(19004, LocaleKeys.MODEL_NOT_SUPPORTED),
CHAT_THREAD_NOT_FOUND(19005, LocaleKeys.CHAT_THREAD_NOT_FOUND),
THREAD_NAME_TOO_LONG(19006, LocaleKeys.THREAD_NAME_TOO_LONG),
;

private final Integer code;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ public enum LocaleKeys {
CREDIT_INCORRECT("credit.incorrect"),
MODEL_NOT_SUPPORTED("model.not.supported"),
CHAT_THREAD_NOT_FOUND("chat.thread.not.found"),
THREAD_NAME_TOO_LONG("thread.name.too.long"),

CHAT_LANGUAGE_PROMPT("chat.language.prompt"),
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ public interface ChatThreadConverter {
ChatThreadConverter INSTANCE = Mappers.getMapper(ChatThreadConverter.class);

@Mapping(source = "id", target = "threadId")
@Mapping(source = "name", target = "threadName")
ChatThreadVO fromPO2VO(ChatThreadPO platformAuthorizedPO);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.bigtop.manager.server.model.req;

import lombok.Data;

import jakarta.validation.constraints.NotEmpty;

@Data
public class ChatbotThreadReq {
@NotEmpty
private String newName;
}
Loading
Loading