1
1
package com .apolloconfig .apollo .ai .qabot .controller ;
2
2
3
3
import com .apolloconfig .apollo .ai .qabot .api .AiService ;
4
+ import com .apolloconfig .apollo .ai .qabot .api .VectorDBService ;
4
5
import com .apolloconfig .apollo .ai .qabot .markdown .MarkdownSearchResult ;
5
6
import com .google .common .base .Strings ;
6
7
import com .google .common .collect .Lists ;
7
- import com .apolloconfig . apollo . ai . qabot . api . VectorDBService ;
8
+ import com .theokanning . openai . completion . chat . ChatCompletionChunk ;
8
9
import com .theokanning .openai .embedding .Embedding ;
10
+ import io .reactivex .Flowable ;
9
11
import java .util .Collections ;
10
12
import java .util .List ;
11
13
import java .util .Set ;
14
+ import java .util .concurrent .atomic .AtomicInteger ;
12
15
import java .util .stream .Collectors ;
13
16
import org .slf4j .Logger ;
14
17
import org .slf4j .LoggerFactory ;
15
18
import org .springframework .beans .factory .annotation .Value ;
19
+ import org .springframework .http .MediaType ;
20
+ import org .springframework .web .bind .annotation .GetMapping ;
16
21
import org .springframework .web .bind .annotation .PostMapping ;
17
22
import org .springframework .web .bind .annotation .RequestMapping ;
18
23
import org .springframework .web .bind .annotation .RequestParam ;
19
24
import org .springframework .web .bind .annotation .RestController ;
25
+ import org .springframework .web .server .ServerWebExchange ;
26
+ import reactor .core .publisher .Flux ;
27
+ import reactor .core .publisher .Mono ;
20
28
21
29
@ RestController
22
30
@ RequestMapping ("/qa" )
@@ -38,54 +46,109 @@ public QAController(AiService aiService, VectorDBService vectorDBService) {
38
46
this .vectorDBService = vectorDBService ;
39
47
}
40
48
41
- @ PostMapping
42
- public Answer qa (@ RequestParam String question ) {
49
+ @ GetMapping ( produces = MediaType . TEXT_EVENT_STREAM_VALUE )
50
+ public Flux < Answer > qa (@ RequestParam String question ) {
43
51
question = question .trim ();
44
52
if (Strings .isNullOrEmpty (question )) {
45
- return Answer .EMPTY ;
53
+ return Flux . just ( Answer .EMPTY ) ;
46
54
}
47
55
48
56
try {
49
57
return doQA (question );
50
58
} catch (Throwable exception ) {
51
59
LOGGER .error ("Error while calling OpenAI API" , exception );
52
- return Answer .ERROR ;
60
+ return Flux . just ( Answer .ERROR ) ;
53
61
}
54
62
}
55
63
56
- private Answer doQA (String question ) {
57
- List <Embedding > embeddings = aiService .getEmbeddings (Lists .newArrayList (question ));
64
+ /**
65
+ * @deprecated Use {@link #qa(String)} instead.
66
+ */
67
+ @ Deprecated
68
+ @ PostMapping
69
+ public Mono <Answer > qaSync (ServerWebExchange serverWebExchange ) {
70
+ Mono <String > field = getFormField (serverWebExchange , "question" );
71
+ return field .flatMap (question -> {
72
+ if (Strings .isNullOrEmpty (question )) {
73
+ return Mono .just (Answer .EMPTY );
74
+ }
75
+
76
+ try {
77
+ Flux <Answer > answer = doQA (question .trim ());
78
+ return answer .reduce ((a1 , a2 ) -> {
79
+ if (Answer .END .answer ().equals (a2 .answer ())) {
80
+ return a1 ;
81
+ }
82
+ a1 .relatedFiles ().addAll (a2 .relatedFiles );
83
+
84
+ return new Answer (a1 .answer () + a2 .answer (), a1 .relatedFiles );
85
+ });
86
+ } catch (Throwable exception ) {
87
+ LOGGER .error ("Error while calling OpenAI API" , exception );
88
+ return Mono .just (Answer .ERROR );
89
+ }
90
+ });
91
+ }
58
92
59
- List <List <Float >> searchVectors = Collections .singletonList (
60
- embeddings .get (0 ).getEmbedding ().stream ()
61
- .map (Double ::floatValue ).collect (Collectors .toList ()));
93
+ private Mono <String > getFormField (ServerWebExchange exchange , String fieldName ) {
94
+ return exchange .getFormData ()
95
+ .flatMap (data -> Mono .justOrEmpty (data .getFirst (fieldName )));
96
+ }
62
97
63
- List <MarkdownSearchResult > searchResults = vectorDBService .search (searchVectors , topK );
98
+ private Flux <Answer > doQA (String question ) {
99
+ List <MarkdownSearchResult > searchResults = searchFromVectorDB (question );
64
100
65
101
if (searchResults .isEmpty ()) {
66
- return Answer .UNKNOWN ;
102
+ return Flux . just ( Answer .UNKNOWN ) ;
67
103
}
68
104
69
105
Set <String > relatedFiles = searchResults .stream ()
70
106
.map (MarkdownSearchResult ::getFileRoot ).collect (Collectors .toSet ());
71
107
72
- StringBuilder sb = new StringBuilder ();
73
- searchResults .forEach (
74
- markdownSearchResult -> sb .append (markdownSearchResult .getContent ()).append ("\n " ));
75
-
76
- String promptMessage = prompt .replace ("{question}" , question )
77
- .replace ("{context}" , sb .toString ());
108
+ String promptMessage = assemblePromptMessage (searchResults , question );
78
109
79
- String answer = aiService .getCompletion (promptMessage );
110
+ Flowable < ChatCompletionChunk > result = aiService .getCompletion (promptMessage );
80
111
81
112
if (LOGGER .isDebugEnabled ()) {
82
- LOGGER .debug ("\n Prompt message: {}\n Answer: {} " , promptMessage , answer );
113
+ LOGGER .debug ("\n Prompt message: {}" , promptMessage );
83
114
}
84
115
85
- return new Answer (answer , relatedFiles );
116
+ final AtomicInteger counter = new AtomicInteger ();
117
+ Flux <Answer > flux = Flux .from (result .filter (
118
+ chatCompletionChunk -> chatCompletionChunk .getChoices ().get (0 ).getMessage ().getContent ()
119
+ != null ).map (chatCompletionChunk -> {
120
+ String value = chatCompletionChunk .getChoices ().get (0 ).getMessage ().getContent ();
121
+ if (LOGGER .isDebugEnabled ()) {
122
+ System .out .print (value );
123
+ }
124
+
125
+ return counter .incrementAndGet () == 1 ? new Answer (value , relatedFiles )
126
+ : new Answer (value , Collections .emptySet ());
127
+ }));
128
+
129
+ return flux .concatWith (Flux .just (Answer .END ));
130
+ }
131
+
132
+ private List <MarkdownSearchResult > searchFromVectorDB (String question ) {
133
+ List <Embedding > embeddings = aiService .getEmbeddings (Lists .newArrayList (question ));
134
+
135
+ List <List <Float >> searchVectors = Collections .singletonList (
136
+ embeddings .get (0 ).getEmbedding ().stream ()
137
+ .map (Double ::floatValue ).collect (Collectors .toList ()));
138
+
139
+ return vectorDBService .search (searchVectors , topK );
86
140
}
87
141
88
- static class Answer {
142
+ private String assemblePromptMessage (List <MarkdownSearchResult > searchResults , String question ) {
143
+ StringBuilder sb = new StringBuilder ();
144
+ searchResults .forEach (
145
+ markdownSearchResult -> sb .append (markdownSearchResult .getContent ()).append ("\n " ));
146
+
147
+ return prompt .replace ("{question}" , question )
148
+ .replace ("{context}" , sb .toString ());
149
+ }
150
+
151
+ public record Answer (String answer , Set <String > relatedFiles ) {
89
152
90
153
static final Answer EMPTY = new Answer ("" , Collections .emptySet ());
91
154
static final Answer UNKNOWN = new Answer ("Sorry, I don't know the answer." ,
@@ -95,20 +158,6 @@ static class Answer {
95
158
"Sorry, I can't answer your question right now. Please try again later." ,
96
159
Collections .emptySet ());
97
160
98
- private final String answer ;
99
- private final Set <String > relatedFiles ;
100
-
101
- public Answer (String answer , Set <String > relatedFiles ) {
102
- this .answer = answer ;
103
- this .relatedFiles = relatedFiles ;
104
- }
105
-
106
- public String getAnswer () {
107
- return answer ;
108
- }
109
-
110
- public Set <String > getRelatedFiles () {
111
- return relatedFiles ;
112
- }
161
+ static final Answer END = new Answer ("$END$" , Collections .emptySet ());
113
162
}
114
163
}
0 commit comments