1
1
package com .apolloconfig .apollo .ai .qabot .milvus ;
2
2
3
- import com .google .common .collect .Lists ;
4
3
import com .apolloconfig .apollo .ai .qabot .api .VectorDBService ;
5
4
import com .apolloconfig .apollo .ai .qabot .config .MilvusConfig ;
6
5
import com .apolloconfig .apollo .ai .qabot .markdown .MarkdownSearchResult ;
6
+ import com .google .common .collect .Lists ;
7
7
import com .theokanning .openai .embedding .Embedding ;
8
8
import io .milvus .client .MilvusServiceClient ;
9
9
import io .milvus .common .clientenum .ConsistencyLevelEnum ;
32
32
import java .util .Arrays ;
33
33
import java .util .Collections ;
34
34
import java .util .List ;
35
+ import java .util .Random ;
35
36
import java .util .stream .Collectors ;
36
37
import org .springframework .context .annotation .Profile ;
37
38
import org .springframework .stereotype .Service ;
@@ -43,6 +44,7 @@ class MilvusService implements VectorDBService {
43
44
44
45
private final MilvusServiceClient milvusServiceClient ;
45
46
private final MilvusConfig milvusConfig ;
47
+ private final List <Float > dummyEmbeddings = Lists .newArrayList ();
46
48
47
49
public MilvusService (MilvusConfig milvusConfig ) {
48
50
this .milvusConfig = milvusConfig ;
@@ -160,7 +162,7 @@ private List<Long> queryChunkIdByFileRoot(String fileRoot) {
160
162
R <RpcStatus > loadStatus = milvusServiceClient .loadCollection (
161
163
loadCollectionParam );
162
164
163
- List <String > query_output_fields = Arrays . asList ("chunk_id" );
165
+ List <String > query_output_fields = List . of ("chunk_id" );
164
166
QueryParam queryParam = QueryParam .newBuilder ()
165
167
.withCollectionName (milvusConfig .getCollection ())
166
168
.withConsistencyLevel (ConsistencyLevelEnum .STRONG )
@@ -169,6 +171,10 @@ private List<Long> queryChunkIdByFileRoot(String fileRoot) {
169
171
.build ();
170
172
R <QueryResults > respQuery = milvusServiceClient .query (queryParam );
171
173
174
+ if (respQuery .getStatus () != Status .Success .getCode ()) {
175
+ throw new RuntimeException ("Query failed: " + respQuery .getMessage ());
176
+ }
177
+
172
178
QueryResultsWrapper wrapperQuery = new QueryResultsWrapper (respQuery .getData ());
173
179
List <?> chunkIds = wrapperQuery .getFieldWrapper ("chunk_id" ).getFieldData ();
174
180
@@ -180,8 +186,116 @@ private List<Long> queryChunkIdByFileRoot(String fileRoot) {
180
186
.collect (Collectors .toList ());
181
187
}
182
188
189
+ @ Override
190
+ public String queryFileHashValue (String fileRoot ) {
191
+ LoadCollectionParam loadCollectionParam = LoadCollectionParam .newBuilder ()
192
+ .withCollectionName (milvusConfig .getFileCollection ())
193
+ .build ();
194
+
195
+ R <RpcStatus > loadStatus = milvusServiceClient .loadCollection (
196
+ loadCollectionParam );
197
+
198
+ List <String > query_output_fields = List .of ("hash_value" );
199
+ QueryParam queryParam = QueryParam .newBuilder ()
200
+ .withCollectionName (milvusConfig .getFileCollection ())
201
+ .withConsistencyLevel (ConsistencyLevelEnum .STRONG )
202
+ .withExpr (String .format ("file_root in ['%s']" , fileRoot ))
203
+ .withOutFields (query_output_fields )
204
+ .build ();
205
+ R <QueryResults > respQuery = milvusServiceClient .query (queryParam );
206
+
207
+ if (respQuery .getStatus () != Status .Success .getCode ()) {
208
+ throw new RuntimeException ("Query failed: " + respQuery .getMessage ());
209
+ }
210
+
211
+ QueryResultsWrapper wrapperQuery = new QueryResultsWrapper (respQuery .getData ());
212
+ List <?> hashValues = wrapperQuery .getFieldWrapper ("hash_value" ).getFieldData ();
213
+
214
+ if (CollectionUtils .isEmpty (hashValues )) {
215
+ return null ;
216
+ }
217
+
218
+ return hashValues .get (0 ).toString ();
219
+ }
220
+
221
+ @ Override
222
+ public void persistFile (String fileRoot , String hashValue ) {
223
+ List <Long > currentFileIds = queryFileIdByFileRoot (fileRoot );
224
+
225
+ List <Field > fields = new ArrayList <>();
226
+ fields .add (new InsertParam .Field ("hash_value" , List .of (hashValue )));
227
+ fields .add (new InsertParam .Field ("dummy_embedding" , List .of (dummyEmbeddings )));
228
+ fields .add (new InsertParam .Field ("file_root" , List .of (fileRoot )));
229
+
230
+ InsertParam insertParam = InsertParam .newBuilder ()
231
+ .withCollectionName (milvusConfig .getFileCollection ())
232
+ .withFields (fields )
233
+ .build ();
234
+ milvusServiceClient .insert (insertParam );
235
+
236
+ deleteByFileIdList (currentFileIds );
237
+
238
+ FlushParam flushParam = FlushParam .newBuilder ()
239
+ .withCollectionNames (Lists .newArrayList (milvusConfig .getFileCollection ()))
240
+ .build ();
241
+ milvusServiceClient .flush (flushParam );
242
+ }
243
+
244
+ private void deleteByFileIdList (List <Long > fileIds ) {
245
+ if (!fileIds .isEmpty ()) {
246
+ StringBuilder sb = new StringBuilder ();
247
+ sb .append ("file_id in [" );
248
+ for (int i = 0 ; i < fileIds .size (); i ++) {
249
+ sb .append (fileIds .get (i ));
250
+ if (i != fileIds .size () - 1 ) {
251
+ sb .append ("," );
252
+ }
253
+ }
254
+ sb .append ("]" );
255
+ DeleteParam deleteParam = DeleteParam .newBuilder ()
256
+ .withCollectionName (milvusConfig .getFileCollection ())
257
+ .withExpr (sb .toString ())
258
+ .build ();
259
+ milvusServiceClient .delete (deleteParam );
260
+ }
261
+ }
262
+
263
+ private List <Long > queryFileIdByFileRoot (String fileRoot ) {
264
+ LoadCollectionParam loadCollectionParam = LoadCollectionParam .newBuilder ()
265
+ .withCollectionName (milvusConfig .getFileCollection ())
266
+ .build ();
267
+
268
+ R <RpcStatus > loadStatus = milvusServiceClient .loadCollection (
269
+ loadCollectionParam );
270
+
271
+ List <String > query_output_fields = List .of ("file_id" );
272
+ QueryParam queryParam = QueryParam .newBuilder ()
273
+ .withCollectionName (milvusConfig .getFileCollection ())
274
+ .withConsistencyLevel (ConsistencyLevelEnum .STRONG )
275
+ .withExpr (String .format ("file_root in ['%s']" , fileRoot ))
276
+ .withOutFields (query_output_fields )
277
+ .build ();
278
+ R <QueryResults > respQuery = milvusServiceClient .query (queryParam );
279
+
280
+ if (respQuery .getStatus () != Status .Success .getCode ()) {
281
+ throw new RuntimeException ("Query failed: " + respQuery .getMessage ());
282
+ }
283
+
284
+ QueryResultsWrapper wrapperQuery = new QueryResultsWrapper (respQuery .getData ());
285
+ List <?> fileIds = wrapperQuery .getFieldWrapper ("file_id" ).getFieldData ();
286
+
287
+ if (CollectionUtils .isEmpty (fileIds )) {
288
+ return Collections .emptyList ();
289
+ }
290
+
291
+ return fileIds .stream ().map (id -> Long .parseLong (id .toString ()))
292
+ .collect (Collectors .toList ());
293
+ }
294
+
295
+
183
296
private void ensureCollections () {
184
297
ensureChunkCollection ();
298
+ ensureFileCollection ();
185
299
}
186
300
187
301
private void ensureChunkCollection () {
@@ -239,5 +353,64 @@ private void ensureChunkCollection() {
239
353
240
354
}
241
355
356
+ private void ensureFileCollection () {
357
+ // prepare dummy embedding data
358
+ Random random = new Random ();
359
+ for (int i = 0 ; i < 1536 ; i ++) {
360
+ dummyEmbeddings .add (random .nextFloat ());
361
+ }
362
+
363
+ HasCollectionParam hasCollectionParam = HasCollectionParam .newBuilder ()
364
+ .withCollectionName (milvusConfig .getFileCollection ())
365
+ .build ();
366
+
367
+ if (milvusServiceClient .hasCollection (hasCollectionParam ).getData ()) {
368
+ return ;
369
+ }
370
+
371
+ FieldType fileId = FieldType .newBuilder ()
372
+ .withName ("file_id" )
373
+ .withDataType (DataType .Int64 )
374
+ .withPrimaryKey (true )
375
+ .withAutoID (true )
376
+ .build ();
377
+ FieldType fileRoot = FieldType .newBuilder ()
378
+ .withName ("file_root" )
379
+ .withDataType (DataType .VarChar )
380
+ .withMaxLength (100 )
381
+ .build ();
382
+ FieldType hashValue = FieldType .newBuilder ()
383
+ .withName ("hash_value" )
384
+ .withDataType (DataType .VarChar )
385
+ .withMaxLength (3000 )
386
+ .build ();
387
+ // not used, just for compatibility
388
+ FieldType dummyEmbedding = FieldType .newBuilder ()
389
+ .withName ("dummy_embedding" )
390
+ .withDataType (DataType .FloatVector )
391
+ .withDimension (1536 )
392
+ .build ();
393
+ CreateCollectionParam createCollectionReq = CreateCollectionParam .newBuilder ()
394
+ .withCollectionName (milvusConfig .getFileCollection ())
395
+ .withDescription ("Files for QA Search" )
396
+ .addFieldType (fileId )
397
+ .addFieldType (hashValue )
398
+ .addFieldType (fileRoot )
399
+ .addFieldType (dummyEmbedding )
400
+ .build ();
401
+
402
+ milvusServiceClient .createCollection (createCollectionReq );
403
+
404
+ // not used, just for compatibility
405
+ milvusServiceClient .createIndex (
406
+ CreateIndexParam .newBuilder ()
407
+ .withCollectionName (milvusConfig .getFileCollection ())
408
+ .withFieldName ("dummy_embedding" )
409
+ .withIndexType (IndexType .FLAT )
410
+ .withMetricType (MetricType .L2 )
411
+ .withSyncMode (Boolean .FALSE )
412
+ .build ()
413
+ );
414
+ }
242
415
243
416
}
0 commit comments