Skip to content

Commit 658e712

Browse files
committed
enable stream response to improve the chat experience
1 parent 28ad60f commit 658e712

File tree

13 files changed

+276
-130
lines changed

13 files changed

+276
-130
lines changed

pom.xml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>org.springframework.boot</groupId>
88
<artifactId>spring-boot-starter-parent</artifactId>
9-
<version>3.1.2</version>
9+
<version>3.1.3</version>
1010
<relativePath/> <!-- lookup parent from repository -->
1111
</parent>
1212
<groupId>com.apolloconfig.apollo.ai</groupId>
@@ -16,10 +16,10 @@
1616
<description>a smart qa bot</description>
1717
<properties>
1818
<java.version>17</java.version>
19-
<openai-gpt3-java.version>0.12.0</openai-gpt3-java.version>
20-
<guava.version>31.1-jre</guava.version>
21-
<flexmark.version>0.62.2</flexmark.version>
22-
<milvus.version>2.2.7</milvus.version>
19+
<openai-gpt3-java.version>0.16.0</openai-gpt3-java.version>
20+
<guava.version>32.1.2-jre</guava.version>
21+
<flexmark.version>0.64.8</flexmark.version>
22+
<milvus.version>2.3.0</milvus.version>
2323
</properties>
2424

2525
<dependencyManagement>
@@ -50,7 +50,7 @@
5050
<dependencies>
5151
<dependency>
5252
<groupId>org.springframework.boot</groupId>
53-
<artifactId>spring-boot-starter-web</artifactId>
53+
<artifactId>spring-boot-starter-webflux</artifactId>
5454
</dependency>
5555
<dependency>
5656
<groupId>org.springframework.boot</groupId>

src/main/java/com/apolloconfig/apollo/ai/qabot/api/AiService.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
package com.apolloconfig.apollo.ai.qabot.api;
22

3+
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
34
import com.theokanning.openai.completion.chat.ChatMessage;
45
import com.theokanning.openai.embedding.Embedding;
6+
import io.reactivex.Flowable;
57
import java.util.List;
68

79
public interface AiService {
810

9-
String getCompletion(String prompt);
11+
Flowable<ChatCompletionChunk> getCompletion(String prompt);
1012

11-
String getCompletionFromMessages(List<ChatMessage> messages);
13+
Flowable<ChatCompletionChunk> getCompletionFromMessages(List<ChatMessage> messages);
1214

1315
List<Embedding> getEmbeddings(List<String> chunks);
1416

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
package com.apolloconfig.apollo.ai.qabot.config;
22

3+
import java.util.concurrent.TimeUnit;
34
import org.springframework.context.annotation.Configuration;
4-
import org.springframework.web.servlet.config.annotation.CorsRegistry;
5-
import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
5+
import org.springframework.http.CacheControl;
6+
import org.springframework.web.reactive.config.CorsRegistry;
7+
import org.springframework.web.reactive.config.ResourceHandlerRegistry;
8+
import org.springframework.web.reactive.config.WebFluxConfigurer;
69

710
@Configuration
8-
public class WebConfig implements WebMvcConfigurer {
11+
public class WebConfig implements WebFluxConfigurer {
912

1013
private final CorsProperties corsProperties;
1114

@@ -15,9 +18,15 @@ public WebConfig(CorsProperties corsProperties) {
1518

1619
@Override
1720
public void addCorsMappings(CorsRegistry registry) {
18-
registry.addMapping("/**")
21+
registry.addMapping("/qa/**")
1922
.allowedOrigins(corsProperties.getAllowedOrigins())
20-
.allowedMethods("POST");
23+
.allowedMethods("GET", "POST");
2124
}
2225

26+
@Override
27+
public void addResourceHandlers(ResourceHandlerRegistry registry) {
28+
registry.addResourceHandler("/**")
29+
.addResourceLocations("classpath:/static/")
30+
.setCacheControl(CacheControl.maxAge(1, TimeUnit.DAYS));
31+
}
2332
}

src/main/java/com/apolloconfig/apollo/ai/qabot/controller/HelloController.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import com.apolloconfig.apollo.ai.qabot.api.AiService;
44
import com.google.common.collect.Lists;
5+
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
56
import com.theokanning.openai.completion.chat.ChatMessage;
67
import com.theokanning.openai.completion.chat.ChatMessageRole;
8+
import io.reactivex.Flowable;
79
import org.springframework.web.bind.annotation.GetMapping;
810
import org.springframework.web.bind.annotation.PathVariable;
911
import org.springframework.web.bind.annotation.RequestMapping;
@@ -20,13 +22,16 @@ public HelloController(AiService aiService) {
2022
}
2123

2224
@GetMapping("/{name}")
23-
public String hello(@PathVariable String name) {
25+
public Flowable<String> hello(@PathVariable String name) {
2426
ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(),
2527
"You are an assistant who responds in the style of Dr Seuss.");
2628
ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(),
2729
"write a brief greeting for " + name);
2830

29-
return aiService.getCompletionFromMessages(
31+
Flowable<ChatCompletionChunk> result = aiService.getCompletionFromMessages(
3032
Lists.newArrayList(systemMessage, userMessage));
33+
return result.filter(chatCompletionChunk ->
34+
chatCompletionChunk.getChoices().get(0).getMessage().getContent() != null).map(
35+
chatCompletionChunk -> chatCompletionChunk.getChoices().get(0).getMessage().getContent());
3136
}
3237
}
Lines changed: 86 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
11
package com.apolloconfig.apollo.ai.qabot.controller;
22

33
import com.apolloconfig.apollo.ai.qabot.api.AiService;
4+
import com.apolloconfig.apollo.ai.qabot.api.VectorDBService;
45
import com.apolloconfig.apollo.ai.qabot.markdown.MarkdownSearchResult;
56
import com.google.common.base.Strings;
67
import com.google.common.collect.Lists;
7-
import com.apolloconfig.apollo.ai.qabot.api.VectorDBService;
8+
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
89
import com.theokanning.openai.embedding.Embedding;
10+
import io.reactivex.Flowable;
911
import java.util.Collections;
1012
import java.util.List;
1113
import java.util.Set;
14+
import java.util.concurrent.atomic.AtomicInteger;
1215
import java.util.stream.Collectors;
1316
import org.slf4j.Logger;
1417
import org.slf4j.LoggerFactory;
1518
import org.springframework.beans.factory.annotation.Value;
19+
import org.springframework.http.MediaType;
20+
import org.springframework.web.bind.annotation.GetMapping;
1621
import org.springframework.web.bind.annotation.PostMapping;
1722
import org.springframework.web.bind.annotation.RequestMapping;
1823
import org.springframework.web.bind.annotation.RequestParam;
1924
import org.springframework.web.bind.annotation.RestController;
25+
import org.springframework.web.server.ServerWebExchange;
26+
import reactor.core.publisher.Flux;
27+
import reactor.core.publisher.Mono;
2028

2129
@RestController
2230
@RequestMapping("/qa")
@@ -38,54 +46,109 @@ public QAController(AiService aiService, VectorDBService vectorDBService) {
3846
this.vectorDBService = vectorDBService;
3947
}
4048

41-
@PostMapping
42-
public Answer qa(@RequestParam String question) {
49+
@GetMapping(produces = MediaType.TEXT_EVENT_STREAM_VALUE)
50+
public Flux<Answer> qa(@RequestParam String question) {
4351
question = question.trim();
4452
if (Strings.isNullOrEmpty(question)) {
45-
return Answer.EMPTY;
53+
return Flux.just(Answer.EMPTY);
4654
}
4755

4856
try {
4957
return doQA(question);
5058
} catch (Throwable exception) {
5159
LOGGER.error("Error while calling OpenAI API", exception);
52-
return Answer.ERROR;
60+
return Flux.just(Answer.ERROR);
5361
}
5462
}
5563

56-
private Answer doQA(String question) {
57-
List<Embedding> embeddings = aiService.getEmbeddings(Lists.newArrayList(question));
64+
/**
65+
* @deprecated Use {@link #qa(String)} instead.
66+
*/
67+
@Deprecated
68+
@PostMapping
69+
public Mono<Answer> qaSync(ServerWebExchange serverWebExchange) {
70+
Mono<String> field = getFormField(serverWebExchange, "question");
71+
return field.flatMap(question -> {
72+
if (Strings.isNullOrEmpty(question)) {
73+
return Mono.just(Answer.EMPTY);
74+
}
75+
76+
try {
77+
Flux<Answer> answer = doQA(question.trim());
78+
return answer.reduce((a1, a2) -> {
79+
if (Answer.END.answer().equals(a2.answer())) {
80+
return a1;
81+
}
82+
a1.relatedFiles().addAll(a2.relatedFiles);
83+
84+
return new Answer(a1.answer() + a2.answer(), a1.relatedFiles);
85+
});
86+
} catch (Throwable exception) {
87+
LOGGER.error("Error while calling OpenAI API", exception);
88+
return Mono.just(Answer.ERROR);
89+
}
90+
});
91+
}
5892

59-
List<List<Float>> searchVectors = Collections.singletonList(
60-
embeddings.get(0).getEmbedding().stream()
61-
.map(Double::floatValue).collect(Collectors.toList()));
93+
private Mono<String> getFormField(ServerWebExchange exchange, String fieldName) {
94+
return exchange.getFormData()
95+
.flatMap(data -> Mono.justOrEmpty(data.getFirst(fieldName)));
96+
}
6297

63-
List<MarkdownSearchResult> searchResults = vectorDBService.search(searchVectors, topK);
98+
private Flux<Answer> doQA(String question) {
99+
List<MarkdownSearchResult> searchResults = searchFromVectorDB(question);
64100

65101
if (searchResults.isEmpty()) {
66-
return Answer.UNKNOWN;
102+
return Flux.just(Answer.UNKNOWN);
67103
}
68104

69105
Set<String> relatedFiles = searchResults.stream()
70106
.map(MarkdownSearchResult::getFileRoot).collect(Collectors.toSet());
71107

72-
StringBuilder sb = new StringBuilder();
73-
searchResults.forEach(
74-
markdownSearchResult -> sb.append(markdownSearchResult.getContent()).append("\n"));
75-
76-
String promptMessage = prompt.replace("{question}", question)
77-
.replace("{context}", sb.toString());
108+
String promptMessage = assemblePromptMessage(searchResults, question);
78109

79-
String answer = aiService.getCompletion(promptMessage);
110+
Flowable<ChatCompletionChunk> result = aiService.getCompletion(promptMessage);
80111

81112
if (LOGGER.isDebugEnabled()) {
82-
LOGGER.debug("\nPrompt message: {}\nAnswer: {}", promptMessage, answer);
113+
LOGGER.debug("\nPrompt message: {}", promptMessage);
83114
}
84115

85-
return new Answer(answer, relatedFiles);
116+
final AtomicInteger counter = new AtomicInteger();
117+
Flux<Answer> flux = Flux.from(result.filter(
118+
chatCompletionChunk -> chatCompletionChunk.getChoices().get(0).getMessage().getContent()
119+
!= null).map(chatCompletionChunk -> {
120+
String value = chatCompletionChunk.getChoices().get(0).getMessage().getContent();
121+
if (LOGGER.isDebugEnabled()) {
122+
System.out.print(value);
123+
}
124+
125+
return counter.incrementAndGet() == 1 ? new Answer(value, relatedFiles)
126+
: new Answer(value, Collections.emptySet());
127+
}));
128+
129+
return flux.concatWith(Flux.just(Answer.END));
130+
}
131+
132+
private List<MarkdownSearchResult> searchFromVectorDB(String question) {
133+
List<Embedding> embeddings = aiService.getEmbeddings(Lists.newArrayList(question));
134+
135+
List<List<Float>> searchVectors = Collections.singletonList(
136+
embeddings.get(0).getEmbedding().stream()
137+
.map(Double::floatValue).collect(Collectors.toList()));
138+
139+
return vectorDBService.search(searchVectors, topK);
86140
}
87141

88-
static class Answer {
142+
private String assemblePromptMessage(List<MarkdownSearchResult> searchResults, String question) {
143+
StringBuilder sb = new StringBuilder();
144+
searchResults.forEach(
145+
markdownSearchResult -> sb.append(markdownSearchResult.getContent()).append("\n"));
146+
147+
return prompt.replace("{question}", question)
148+
.replace("{context}", sb.toString());
149+
}
150+
151+
public record Answer(String answer, Set<String> relatedFiles) {
89152

90153
static final Answer EMPTY = new Answer("", Collections.emptySet());
91154
static final Answer UNKNOWN = new Answer("Sorry, I don't know the answer.",
@@ -95,20 +158,6 @@ static class Answer {
95158
"Sorry, I can't answer your question right now. Please try again later.",
96159
Collections.emptySet());
97160

98-
private final String answer;
99-
private final Set<String> relatedFiles;
100-
101-
public Answer(String answer, Set<String> relatedFiles) {
102-
this.answer = answer;
103-
this.relatedFiles = relatedFiles;
104-
}
105-
106-
public String getAnswer() {
107-
return answer;
108-
}
109-
110-
public Set<String> getRelatedFiles() {
111-
return relatedFiles;
112-
}
161+
static final Answer END = new Answer("$END$", Collections.emptySet());
113162
}
114163
}

src/main/java/com/apolloconfig/apollo/ai/qabot/markdown/MarkdownProcessor.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import com.apolloconfig.apollo.ai.qabot.api.VectorDBService;
55
import com.apolloconfig.apollo.ai.qabot.config.MarkdownFilesConfig;
66
import com.apolloconfig.apollo.ai.qabot.config.MarkdownProcessorRetryConfig;
7-
import com.google.common.collect.Maps;
87
import com.theokanning.openai.embedding.Embedding;
98
import com.vladsch.flexmark.ast.Heading;
109
import com.vladsch.flexmark.parser.Parser;
@@ -18,7 +17,6 @@
1817
import java.security.NoSuchAlgorithmException;
1918
import java.util.ArrayList;
2019
import java.util.List;
21-
import java.util.Map;
2220
import java.util.Objects;
2321
import java.util.stream.Stream;
2422
import org.slf4j.Logger;

src/main/java/com/apolloconfig/apollo/ai/qabot/openai/OpenAiService.java

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
package com.apolloconfig.apollo.ai.qabot.openai;
22

3-
import com.google.common.collect.Lists;
43
import com.apolloconfig.apollo.ai.qabot.api.AiService;
4+
import com.google.common.collect.Lists;
5+
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
56
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
67
import com.theokanning.openai.completion.chat.ChatMessage;
78
import com.theokanning.openai.completion.chat.ChatMessageRole;
89
import com.theokanning.openai.embedding.Embedding;
910
import com.theokanning.openai.embedding.EmbeddingRequest;
11+
import io.reactivex.Flowable;
1012
import java.util.List;
1113
import org.springframework.context.annotation.Profile;
1214
import org.springframework.stereotype.Component;
@@ -24,20 +26,22 @@ public OpenAiService() {
2426
service = OpenAiServiceFactory.getService(System.getenv("OPENAI_API_KEY"));
2527
}
2628

27-
public String getCompletion(String prompt) {
29+
public Flowable<ChatCompletionChunk> getCompletion(String prompt) {
2830
ChatMessage message = new ChatMessage(ChatMessageRole.USER.value(), prompt);
2931
return getCompletionFromMessages(Lists.newArrayList(message));
3032
}
3133

32-
public String getCompletionFromMessages(List<ChatMessage> messages) {
34+
public Flowable<ChatCompletionChunk> getCompletionFromMessages(List<ChatMessage> messages) {
3335
return getCompletionFromMessages(messages, 0.0);
3436
}
3537

36-
public String getCompletionFromMessages(List<ChatMessage> messages, double temperature) {
38+
public Flowable<ChatCompletionChunk> getCompletionFromMessages(List<ChatMessage> messages,
39+
double temperature) {
3740
return getCompletionFromMessages(messages, DEFAULT_MODEL, temperature, 500);
3841
}
3942

40-
public String getCompletionFromMessages(List<ChatMessage> messages, String model,
43+
public Flowable<ChatCompletionChunk> getCompletionFromMessages(List<ChatMessage> messages,
44+
String model,
4145
double temperature, int maxTokens) {
4246
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
4347
.builder()
@@ -47,8 +51,7 @@ public String getCompletionFromMessages(List<ChatMessage> messages, String model
4751
.maxTokens(maxTokens)
4852
.build();
4953

50-
return service.createChatCompletion(chatCompletionRequest).getChoices().get(0).getMessage()
51-
.getContent();
54+
return service.streamChatCompletion(chatCompletionRequest);
5255
}
5356

5457
public List<Embedding> getEmbeddings(List<String> chunks) {

src/main/java/com/apolloconfig/apollo/ai/qabot/openai/OpenAiServiceFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import com.fasterxml.jackson.databind.ObjectMapper;
88
import com.google.common.base.Strings;
99
import com.google.common.collect.Maps;
10-
import com.theokanning.openai.OpenAiApi;
10+
import com.theokanning.openai.client.OpenAiApi;
1111
import com.theokanning.openai.service.OpenAiService;
1212
import java.io.IOException;
1313
import java.net.InetAddress;

0 commit comments

Comments
 (0)