diff --git a/api/settings.py b/api/settings.py
index 3148633e6e5..ac52cbe5cbf 100644
--- a/api/settings.py
+++ b/api/settings.py
@@ -23,6 +23,7 @@
 import rag.utils.es_conn
 import rag.utils.infinity_conn
 import rag.utils.opensearch_conn
+import rag.utils.baidu_vdb_conn
 from api.constants import RAG_FLOW_SERVICE_NAME
 from api.utils import decrypt_database_config, get_base_config
 from api.utils.file_utils import get_project_base_directory
@@ -184,10 +185,12 @@ def init_settings():
         docStoreConn = rag.utils.infinity_conn.InfinityConnection()
     elif lower_case_doc_engine == "opensearch":
         docStoreConn = rag.utils.opensearch_conn.OSConnection()
+    elif lower_case_doc_engine == "baiduvdb":
+        docStoreConn = rag.utils.baidu_vdb_conn.BaiduVDBConnection()
     else:
         raise Exception(f"Not supported doc engine: {DOC_ENGINE}")
 
-    retrievaler = search.Dealer(docStoreConn)
+    retrievaler = search.Dealer(dataStore=docStoreConn, docEngine=DOC_ENGINE)
 
     from graphrag import search as kg_search
     kg_retrievaler = kg_search.KGSearch(docStoreConn)
diff --git a/conf/mochow_mapping.json b/conf/mochow_mapping.json
new file mode 100644
index 00000000000..6d5abac68f2
--- /dev/null
+++ b/conf/mochow_mapping.json
@@ -0,0 +1,42 @@
+{
+    "id": {"fieldType": "STRING", "primaryKey": true, "partitionKey": true, "default": ""},
+    "doc_id": {"fieldType": "STRING", "default": ""},
+    "kb_id": {"fieldType": "ARRAY", "elementType": "STRING", "default": [""]},
+    "create_time": {"fieldType": "STRING", "default": ""},
+    "create_timestamp_flt": {"fieldType": "FLOAT", "default": 0.0},
+    "img_id": {"fieldType": "STRING", "default": ""},
+    "docnm_kwd": {"fieldType": "STRING", "default": ""},
+    "title_tks": {"fieldType": "TEXT", "default": "", "analyzer": "DEFAULT_ANALYZER"},
+    "title_sm_tks": {"fieldType": "TEXT", "default": "", "analyzer": "DEFAULT_ANALYZER"},
+    "name_kwd": {"fieldType": "STRING", "default": ""},
+    "important_kwd": {"fieldType": "ARRAY", "elementType": "STRING", "default": [""]},
+    "tag_kwd": {"fieldType": "ARRAY", "elementType": "STRING", "default": [""]},
+    "important_tks": {"fieldType": "TEXT", "default": "", "analyzer": "DEFAULT_ANALYZER"},
+    "question_kwd": {"fieldType": "ARRAY", "elementType": "STRING", "default": [""]},
+    "question_tks": {"fieldType": "TEXT", "default": "", "analyzer": "DEFAULT_ANALYZER"},
+    "content_with_weight": {"fieldType": "STRING", "default": ""},
+    "content_ltks": {"fieldType": "TEXT", "default": "", "analyzer": "DEFAULT_ANALYZER"},
+    "content_sm_ltks": {"fieldType": "TEXT", "default": "", "analyzer": "DEFAULT_ANALYZER"},
+    "authors_tks": {"fieldType": "TEXT", "default": "", "analyzer": "DEFAULT_ANALYZER"},
+    "authors_sm_tks": {"fieldType": "TEXT", "default": "", "analyzer": "DEFAULT_ANALYZER"},
+    "page_num_int": {"fieldType": "ARRAY", "elementType": "INT64", "default": [0]},
+    "top_int": {"fieldType": "ARRAY", "elementType": "INT64", "default": [0]},
+    "position_int": {"fieldType": "ARRAY", "elementType": "INT64", "default": [0]},
+    "weight_int": {"fieldType": "INT64", "default": 0},
+    "weight_flt": {"fieldType": "FLOAT", "default": 0.0},
+    "rank_int": {"fieldType": "INT64", "default": 0},
+    "rank_flt": {"fieldType": "FLOAT", "default": 0.0},
+    "available_int": {"fieldType": "INT64", "default": 1},
+    "knowledge_graph_kwd": {"fieldType": "STRING", "default": ""},
+    "entities_kwd": {"fieldType": "ARRAY", "elementType": "STRING", "default": [""]},
+    "pagerank_fea": {"fieldType": "INT64", "default": 0},
+    "tag_feas": {"fieldType": "STRING", "default": ""},
+
+    "from_entity_kwd": {"fieldType": "STRING", "default": ""},
"default": ""}, + "to_entity_kwd": {"fieldType": "STRING", "default": ""}, + "entity_kwd": {"fieldType": "STRING", "default": ""}, + "entity_type_kwd": {"fieldType": "STRING", "default": ""}, + "source_id": {"fieldType": "ARRAY", "elementType": "STRING", "default": [""]}, + "n_hop_with_weight": {"fieldType": "STRING", "default": ""}, + "removed_kwd": {"fieldType": "STRING", "default": ""} +} diff --git a/conf/service_conf.yaml b/conf/service_conf.yaml index 7b76f2b4f13..dd1a91b2840 100644 --- a/conf/service_conf.yaml +++ b/conf/service_conf.yaml @@ -29,6 +29,11 @@ redis: db: 1 password: 'infini_rag_flow' host: 'localhost:6379' +baiduvdb: + endpoint: 'http://127.0.0.:5287' + username: 'root' + password: 'mochow' + replication: 1 # postgres: # name: 'rag_flow' # user: 'rag_flow' diff --git a/docker/.env b/docker/.env index 1375f5564eb..91a33b6a9c1 100644 --- a/docker/.env +++ b/docker/.env @@ -53,7 +53,13 @@ INFINITY_THRIFT_PORT=23817 INFINITY_HTTP_PORT=23820 INFINITY_PSQL_PORT=5432 -# The password for MySQL. +# config of baidu vdb +BAIDU_VDB_ENDPOINT=http://127.0.0.1:5287 +BAIDU_VDB_USER=root +BAIDU_VDB_PASSWORD=password +BAIDU_VDB_REPLICATION=1 + +# The password for MySQL. MYSQL_PASSWORD=infini_rag_flow # The hostname where the MySQL service is exposed MYSQL_HOST=mysql diff --git a/docker/service_conf.yaml.template b/docker/service_conf.yaml.template index 5db35b9c7f2..04017bf5168 100644 --- a/docker/service_conf.yaml.template +++ b/docker/service_conf.yaml.template @@ -29,6 +29,11 @@ redis: db: 1 password: '${REDIS_PASSWORD:-infini_rag_flow}' host: '${REDIS_HOST:-redis}:6379' +baiduvdb: + endpoint: 'http://${BAIDU_VDB_ENDPOINT:-http://127.0.0.1:5287}' + username: '${BAIDU_VDB_USER:-root}' + password: '${BAIDU_VDB_PASSWORD:-password}' + replication: '${BAIDU_VDB_REPLICATION:-1}' # postgres: # name: '${POSTGRES_DBNAME:-rag_flow}' # user: '${POSTGRES_USER:-rag_flow}' diff --git a/pyproject.toml b/pyproject.toml index c82956d777d..1d1b04c5dc6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ dependencies = [ "debugpy>=1.8.13", "mcp>=1.9.4", "opensearch-py==2.7.1", + "pymochow==2.2.9" "pluginlib==0.9.4", "click>=8.1.8", "python-calamine>=0.4.0", diff --git a/rag/nlp/query.py b/rag/nlp/query.py index b708ff490a0..54bed8307e8 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -24,7 +24,7 @@ class FulltextQueryer: - def __init__(self): + def __init__(self, docEngine: str = "elasticsearch"): self.tw = term_weight.Dealer() self.syn = synonym.Dealer() self.query_fields = [ @@ -36,6 +36,15 @@ def __init__(self): "content_ltks^2", "content_sm_ltks", ] + self.baidu_vdb_query_fields = [ + "title_tks", + "title_sm_tks", + "important_tks", + "question_tks", + "content_ltks", + "content_sm_ltks", + ] + self.docEngine = docEngine @staticmethod def subSpecialChar(line): @@ -83,6 +92,8 @@ def add_space_between_eng_zh(txt): return txt def question(self, txt, tbl="qa", min_match: float = 0.6): + if self.docEngine == "baiduvdb": + return self.question_by_baidu_vdb(txt=txt) txt = FulltextQueryer.add_space_between_eng_zh(txt) txt = re.sub( r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+", @@ -216,6 +227,84 @@ def need_fine_grained_tokenize(tk): ), keywords return None, keywords + def question_by_baidu_vdb(self, txt): + txt = re.sub( + r"[ :|\r\n\t,,。??/`!!&^%%()\[\]{}<>]+", + " ", + rag_tokenizer.tradi2simp(rag_tokenizer.strQ2B(txt.lower())), + ).strip() + txt = FulltextQueryer.rmWWW(txt) + + if not self.isChinese(txt): + txt = FulltextQueryer.rmWWW(txt) + tks = rag_tokenizer.tokenize(txt).split() + 
+            keywords = [t for t in tks if t]
+            tks_w = self.tw.weights(tks, preprocess=False)
+            tks_w = [(re.sub(r"[ \\\"'^]", "", tk), w) for tk, w in tks_w]
+            tks_w = [(re.sub(r"^[a-z0-9]$", "", tk), w) for tk, w in tks_w if tk]
+            tks_w = [(re.sub(r"^[\+-]", "", tk), w) for tk, w in tks_w if tk]
+            tks_w = [(tk.strip(), w) for tk, w in tks_w if tk.strip()]
+            for tk, w in tks_w[:256]:
+                syn = self.syn.lookup(tk)
+                syn = rag_tokenizer.tokenize(" ".join(syn)).split()
+                keywords.extend(syn)
+
+            return MatchTextExpr(
+                self.baidu_vdb_query_fields, txt, 100
+            ), keywords
+
+        def need_fine_grained_tokenize(tk):
+            if len(tk) < 3:
+                return False
+            if re.match(r"[0-9a-z\.\+#_\*-]+$", tk):
+                return False
+            return True
+
+        txt = FulltextQueryer.rmWWW(txt)
+        keywords = []
+        for tt in self.tw.split(txt)[:256]:  # .split():
+            if not tt:
+                continue
+            keywords.append(tt)
+            twts = self.tw.weights([tt])
+            syns = self.syn.lookup(tt)
+            if syns and len(keywords) < 32:
+                keywords.extend(syns)
+            logging.debug(json.dumps(twts, ensure_ascii=False))
+            for tk, w in sorted(twts, key=lambda x: x[1] * -1):
+                sm = (
+                    rag_tokenizer.fine_grained_tokenize(tk).split()
+                    if need_fine_grained_tokenize(tk)
+                    else []
+                )
+                sm = [
+                    re.sub(
+                        r"[ ,\./;'\[\]\\`~!@#$%\^&\*\(\)=\+_<>\?:\"\{\}\|,。;‘’【】、!¥……()——《》?:“”-]+",
+                        "",
+                        m,
+                    )
+                    for m in sm
+                ]
+                sm = [FulltextQueryer.subSpecialChar(m) for m in sm if len(m) > 1]
+                sm = [m for m in sm if len(m) > 1]
+
+                if len(keywords) < 32:
+                    keywords.append(re.sub(r"[ \\\"']+", "", tk))
+                    keywords.extend(sm)
+
+                tk_syns = self.syn.lookup(tk)
+                tk_syns = [FulltextQueryer.subSpecialChar(s) for s in tk_syns]
+                if len(keywords) < 32:
+                    keywords.extend([s for s in tk_syns if s])
+                tk_syns = [rag_tokenizer.fine_grained_tokenize(s) for s in tk_syns if s]
+                tk_syns = [f"\"{s}\"" if s.find(" ") > 0 else s for s in tk_syns]
+
+            if len(keywords) >= 32:
+                break
+        return MatchTextExpr(
+            self.baidu_vdb_query_fields, txt, 100
+        ), keywords
+
     def hybrid_similarity(self, avec, bvecs, atks, btkss, tkweight=0.3, vtweight=0.7):
         from sklearn.metrics.pairwise import cosine_similarity as CosineSimilarity
         import numpy as np
diff --git a/rag/nlp/search.py b/rag/nlp/search.py
index b1617b9a7b5..a0216679792 100644
--- a/rag/nlp/search.py
+++ b/rag/nlp/search.py
@@ -30,9 +30,10 @@ def index_name(uid):
     return f"ragflow_{uid}"
 
 class Dealer:
-    def __init__(self, dataStore: DocStoreConnection):
-        self.qryr = query.FulltextQueryer()
+    def __init__(self, dataStore: DocStoreConnection, docEngine: str = "elasticsearch"):
+        self.qryr = query.FulltextQueryer(docEngine=docEngine)
         self.dataStore = dataStore
+        self.docEngine = docEngine
 
     @dataclass
     class SearchResult:
diff --git a/rag/settings.py b/rag/settings.py
index 70d1b6234cc..0be1e277bf3 100644
--- a/rag/settings.py
+++ b/rag/settings.py
@@ -27,6 +27,7 @@
 
 ES = {}
 INFINITY = {}
+BAIDUVDB = {}
 AZURE = {}
 S3 = {}
 MINIO = {}
@@ -40,6 +41,8 @@
     OS = get_base_config("os", {})
 elif DOC_ENGINE == 'infinity':
     INFINITY = get_base_config("infinity", {"uri": "infinity:23817"})
+elif DOC_ENGINE == 'baiduvdb':
+    BAIDUVDB = get_base_config("baiduvdb", {})
 
 if STORAGE_IMPL_TYPE in ['AZURE_SPN', 'AZURE_SAS']:
     AZURE = get_base_config("azure", {})
diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py
index 84c73d2b695..ce7f5e51e47 100644
--- a/rag/svr/task_executor.py
+++ b/rag/svr/task_executor.py
@@ -648,7 +648,7 @@ async def delete_image(kb_id, chunk_id):
         if b % 128 == 0:
             progress_callback(prog=0.8 + 0.1 * (b + 1) / len(chunks), msg="")
         if doc_store_result:
-            error_message = f"Insert chunk 
error: {doc_store_result}, please check log file and Elasticsearch/Infinity status!" + error_message = f"Insert chunk error: {doc_store_result}, please check log file and Elasticsearch/Infinity/BaiduVDB status!" progress_callback(-1, msg=error_message) raise Exception(error_message) chunk_ids = [chunk["id"] for chunk in chunks[:b + DOC_BULK_SIZE]] diff --git a/rag/utils/baidu_vdb_conn.py b/rag/utils/baidu_vdb_conn.py new file mode 100644 index 00000000000..031e97fd3cc --- /dev/null +++ b/rag/utils/baidu_vdb_conn.py @@ -0,0 +1,653 @@ +# +# Copyright 2025 The Baidu Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging +import re +import json +import os + +import copy + +import pymochow +from pymochow.configuration import Configuration +from pymochow.auth.bce_credentials import BceCredentials +from pymochow.model.schema import ( + Schema, Field, IndexField, + VectorIndex, + FilteringIndex, IndexStructureType, + InvertedIndex, InvertedIndexAnalyzer, InvertedIndexParams, InvertedIndexFieldAttribute, InvertedIndexParseMode, + HNSWParams, + AutoBuildRowCountIncrement +) +from pymochow.model.enum import ( + FieldType, IndexType, MetricType, ElementType, + ServerErrCode +) +from pymochow.model.database import Database +from pymochow.model.table import ( + Table, Partition, Row, + VectorTopkSearchRequest, BM25SearchRequest, HybridSearchRequest, + VectorSearchConfig +) +from pymochow.exception import ClientError, ServerError + +from rag import settings +from rag.utils import singleton +import pandas as pd +from api.utils.file_utils import get_project_base_directory +from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, MatchTextExpr, MatchDenseExpr, \ + FusionExpr + +ATTEMPT_TIME = 2 +logger = logging.getLogger('ragflow.baidu_vdb_conn') + + +filter_idx_filed_pattern = re.compile(r"^(.*_(kwd|id|ids|uid|uids)|uid)$") +invertes_idx_filedpattern = re.compile(r".*_(tks|ltks)$") + +@singleton +class BaiduVDBConnection(DocStoreConnection): + def __init__(self): + self.db_name = "ragflow" + self.tks_inverted_idx = "tks_inverted_idx" + self.kwd_filter_idx = "kwd_filter_idx" + logger.info(f"User BaiduVDB {settings.BAIDUVDB['endpoint']} as the doc engine") + config = Configuration(credentials=BceCredentials(settings.BAIDUVDB['username'], settings.BAIDUVDB['password']), + endpoint=settings.BAIDUVDB['endpoint']) + self.client = pymochow.MochowClient(config) + fp_mapping = os.path.join( + get_project_base_directory(), "conf", "mochow_mapping.json" + ) + if not os.path.exists(fp_mapping): + raise Exception(f"Mapping file not found at {fp_mapping}") + self.mapping = json.load(open(fp_mapping)) + healthy = self.health() + self.query_fields_boosts = { + "title_tks": 10, + "title_sm_tks": 5, + "important_tks": 20, + "question_tks": 20, + "content_ltks": 2, + "content_sm_ltks": 1, + } + if healthy["err"] == "": + logger.info(f"BaiduVDB {settings.BAIDUVDB['endpoint']} is healthy.") + else: + logger.warning(f"BaiduVDB {settings.BAIDUVDB['endpoint']} is 
not healthy. error: {healthy.get('err')}") + + + """ + Database operations + """ + + def dbType(self) -> str: + return "BaiduVDB" + + def health(self) -> dict: + res = { + "type": "BaiduVDB" + } + try: + self.client.list_databases() + res["status"] = "normal" + res["err"] = "" + except Exception as e: + res["status"] = "invalid" + res["err"] = str(e) + return res + + """ + Table operations + """ + + def _get_table_schema(self) -> Schema: + fields: list[Field] = [] + indexes: list[IndexField] = [] + + for field, field_attribute in self.mapping.items(): + assert isinstance(field, str) + assert isinstance(field_attribute, dict) + if field == "id": + fields.append(Field(field, FieldType.STRING, primary_key=True, partition_key=True, not_null=True)) + elif field_attribute["fieldType"] == FieldType.ARRAY.value: + fields.append(Field(field_name=field, + field_type=FieldType(field_attribute["fieldType"]), + element_type=ElementType(field_attribute["elementType"]))) + else: + fields.append(Field(field_name=field, field_type=FieldType(field_attribute["fieldType"]))) + + dimension_list: list[int] = [512, 768, 1024, 1536] + + for _, dimension in enumerate(dimension_list): + fields.append(Field( + field_name=f"q_{dimension}_vec", + field_type=FieldType.FLOAT_VECTOR, + dimension=dimension, + not_null=False, + )) + + + vector_index = VectorIndex( + index_name=f"q_{dimension}_vec_idx", + index_type=IndexType.HNSW, + field=f"q_{dimension}_vec", + metric_type=MetricType.COSINE, + params=HNSWParams(m=16,efconstruction=50), + auto_build=True, + auto_build_index_policy=AutoBuildRowCountIncrement(row_count_increment=10000, row_count_increment_ratio=0.2), + ) + + indexes.append(vector_index) + + filter_index_fileds: list[str] = [field.field_name for field in fields if filter_idx_filed_pattern.match(field.field_name)] + indexes.append(FilteringIndex( + index_name=self.kwd_filter_idx, + fields=[{"field": field_name, "indexStructureType": IndexStructureType.BITMAP} for field_name in filter_index_fileds] + )) + + invertes_index_fields: list[str] = [field.field_name for field in fields if invertes_idx_filedpattern.match(field.field_name)] + indexes.append(InvertedIndex( + index_name=self.tks_inverted_idx, + fields=invertes_index_fields, + params=InvertedIndexParams( + analyzer=InvertedIndexAnalyzer.DEFAULT_ANALYZER, + parse_mode=InvertedIndexParseMode.COARSE_MODE, + ), + field_attributes=[InvertedIndexFieldAttribute.ANALYZED] * len(invertes_index_fields), + )) + + schema = Schema(fields=fields, indexes=indexes) + logger.debug(f"create table schema: {str(schema.to_dict())}") + return schema + + def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int): + if self.indexExist(indexName=indexName, knowledgebaseId=knowledgebaseId): + return + db_list: list[Database] = self.client.list_databases() + db_name_list: list[str] = [db.database_name for db in db_list] + has_db = self.db_name in db_name_list + if not has_db: + self.client.create_database(self.db_name) + + table_name = indexName + db = self.client.database(self.db_name) + try: + db.create_table( + table_name=table_name, + replication=settings.BAIDUVDB['replication'], + partition=Partition(partition_num=3), + schema=self._get_table_schema() + ) + except Exception as e: + logger.warning(f"BaiduVDB create index {indexName} failed, error: {str(e)}") + logger.info(f"BaiduVDB create index {indexName} succeed") + + + def deleteIdx(self, indexName: str, knowledgebaseId: str): + if len(knowledgebaseId) > 0: + # The index need to be alive after any kb deletion 
since all kb under this tenant are in one index. + return + try: + db = self.client.database(self.db_name) + table_name = indexName + db.drop_table(table_name=table_name) + except ServerError as e: + if e.code == ServerErrCode.DB_NOT_EXIST: + return + else: + logger.warning(f"BaiduVDB deleteIdx {str(e)}") + except Exception as e: + logger.warning(f"BaiduVDB deleteIdx {str(e)}") + + + def indexExist(self, indexName: str, knowledgebaseId: str) -> bool: + try: + db = self.client.database(self.db_name) + table_name = indexName + table_list: list[Table] = db.list_table() + table_name_list = [table.table_name for table in table_list] + return table_name in table_name_list + except ClientError: + return False + except Exception as e: + logger.warning(f"BaiduVDB indexExist {str(e)}") + return False + return True + + """ + CRUD operations + """ + + def search( + self, + selectFields: list[str], + highlightFields: list[str], + condition: dict, + matchExprs: list[MatchExpr], + orderBy: OrderByExpr, + offset: int, + limit: int, + indexNames: str | list[str], + knowledgebaseIds: list[str], + aggFields: list[str] = ..., + rank_feature: dict | None = None + ) -> tuple[pd.DataFrame, int]: + """ + TODO: vdb not support highlight, agg, rank + """ + if isinstance(indexNames, str): + indexNames = indexNames.split(",") + assert isinstance(indexNames, list) and len(indexNames) > 0 + projections = selectFields.copy() + if len(projections) != 0: + for essential_field in ["id"]: + if essential_field not in projections: + projections.append(essential_field) + + assert "_id" not in condition + condition["kb_id"] = knowledgebaseIds + filter = self._condition_to_filter(condition=condition) + db = self.client.database(self.db_name) + + res = list() + search_req = None + if len(matchExprs) == 0: + # select + select_res = self._select(db, indexNames=indexNames, projections=projections, filter=filter if filter != "" else None) + res.extend(select_res) + elif len(matchExprs) and isinstance(matchExprs[0], MatchTextExpr): + # bm25_search + search_req = self._matchTextExpr2Bm25SearchReq(matchExprs[0], filter if filter != "" else None) + elif len(matchExprs) and isinstance(matchExprs[0], MatchDenseExpr): + # vector_search + search_req = self._matchDenseExpr2VectorSearchReq(matchExprs[0], filter if filter != "" else None) + else: + # hybird_search + is_bybird_search = False + vector_weight = 0.5 + for m in matchExprs: + if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params: + assert len(matchExprs) == 3 \ + and isinstance(matchExprs[0], MatchTextExpr) \ + and isinstance(matchExprs[1], MatchDenseExpr) \ + and isinstance(matchExprs[2], FusionExpr) + weights = m.fusion_params["weights"] + vector_weight = float(weights.split(",")[1]) + is_bybird_search = True + if is_bybird_search: + bm25_search_req = self._matchTextExpr2Bm25SearchReq(matchExprs[0], None) + vector_search_req = self._matchDenseExpr2VectorSearchReq(matchExprs[1], None) + search_req = HybridSearchRequest( + vector_request=vector_search_req, + bm25_request=bm25_search_req, + vector_weight=vector_weight, + bm25_weight=1-vector_weight, + filter=filter if filter != "" else None, + limit=matchExprs[1].topn, + ) + + if search_req is not None: + search_res = self._Search(db=db, indexNames=indexNames, req=search_req, projections=projections) + res.extend(search_res) + + if len(orderBy.fields) > 0: + sort_fields = list() + for order_field in orderBy.fields: + field, desc = order_field[0], order_field[1] + if field in 
["page_num_int","top_int"]: + sort_fields.append((field, min, bool(desc))) + if field == "create_timestamp_flt": + def f(x): + return -x if desc else x + sort_fields.append((field, f, bool(desc))) + def build_sort_key(entry): + return [func(entry[field]) if callable(func) else entry[field] for field, func, _ in sort_fields] + + if len(sort_fields) > 0: + res = sorted(res, key=build_sort_key) + + if limit > 0: + res = res[offset:offset+limit] + return res + + def _select(self, db: Database, indexNames: list[str], projections: list[str], filter: str): + res_list = list() + for table_name in indexNames: + table = db.table(table_name=table_name) + marker = None + while True: + res = table.select(marker=marker, filter=filter, projections=projections, limit=50) + assert isinstance(res.rows, list) + res_list.extend(res.rows) + if res.is_truncated is False: + break + else: + marker = res.next_marker + return res_list + + def _matchTextExpr2Bm25SearchReq(self, m: MatchTextExpr, filter: str) -> BM25SearchRequest: + search_text_cond = list() + for field in m.fields: + boost = 1 + if field in self.query_fields_boosts: + boost = self.query_fields_boosts[field] + if boost != 1: + search_text_cond.append(f"{field}:{m.matching_text}^{boost}") + else: + search_text_cond.append(f"{field}:{m.matching_text}") + search_text = " OR ".join(search_text_cond) + return BM25SearchRequest( + index_name=self.tks_inverted_idx, + search_text=search_text, + filter=filter, + limit=m.topn, + ) + + def _matchDenseExpr2VectorSearchReq(self, m: MatchDenseExpr, filter: str) -> VectorTopkSearchRequest: + config = VectorSearchConfig(ef=m.topn*2) + return VectorTopkSearchRequest( + vector_field=m.vector_column_name, + limit=m.topn, + vector=list(m.embedding_data), + filter=filter, + config=config + ) + + def _Search(self, db: Database, indexNames: list[str], req: BM25SearchRequest|VectorTopkSearchRequest|HybridSearchRequest, projections: list[str]): + res_list = list() + for table_name in indexNames: + table = db.table(table_name=table_name) + res = None + if (isinstance(req, BM25SearchRequest)): + res = table.bm25_search(request=req, projections=projections) + elif (isinstance(req, VectorTopkSearchRequest)): + res = table.vector_search(request=req, projections=projections) + elif (isinstance(req, HybridSearchRequest)): + res = table.hybrid_search(request=req, projections=projections) + if res is None: + continue + assert isinstance(res.rows, list) + for row in res.rows: + res_list.append(row['row']) + return res_list + + def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None: + db = self.client.database(self.db_name) + table_name = indexName + table_instance = db.table(table_name=table_name) + kb_res = table_instance.query(primary_key={'id': chunkId}) + return kb_res.row + + def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]: + # Refers https://cloud.baidu.com/doc/VDB/s/8lrsob128#%E6%9B%B4%E6%96%B0%E6%8F%92%E5%85%A5%E8%AE%B0%E5%BD%95 + docs = copy.deepcopy(documents) + for d in docs: + assert "id" in d + for k, v in d.items(): + if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: + assert isinstance(v, list) + elif re.search(r"_feas$", k): + d[k] = json.dumps(v) + elif k == 'kb_id': + if isinstance(v, str): + d[k] = [v] + elif k == "position_int": + assert isinstance(v, list) + d[k] = [num for row in v for num in row] + if 'kb_id' not in d: + d["kb_id"] = [knowledgebaseId] + # set default value for scalar data + for 
field, field_attribute in self.mapping.items(): + assert isinstance(field, str) + assert isinstance(field_attribute, dict) + if field not in d: + d[field] = field_attribute["default"] + + db = self.client.database(self.db_name) + table_name = indexName + table_instance = db.table(table_name=table_name) + chunk_rows_list = [docs[i:i+300] for i in range(0, len(docs), 300)] + + res = [] + for chunk_rows in chunk_rows_list: + for _ in range(ATTEMPT_TIME): + try: + table_instance.upsert(rows=[Row(**c_row) for c_row in chunk_rows]) + except Exception as e: + res.append(str(e)) + logger.warning("BaiduVDB.upsert got error: " + str(e)) + continue + return res + + def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool: + db = self.client.database(self.db_name) + table_name = indexName + table = db.table(table_name=table_name) + doc = copy.deepcopy(newValue) + doc.pop("id", None) + for k, v in doc.items(): + if k in ["important_kwd", "question_kwd", "entities_kwd", "tag_kwd", "source_id"]: + assert isinstance(v, list) + elif re.search(r"_feas$", k): + doc[k] = json.dumps(v) + elif k == "position_int": + assert isinstance(v, list) + doc[k] = [num for row in v for num in row] + elif k == 'kb_id': + if isinstance(v, str): + doc[k] = [v] + elif k == "remove": + del doc[k] + # replace value with default value + assert k in self.mapping + k_attr = self.mapping[k] + assert isinstance(k_attr, dict) + doc[k] = k_attr["default"] + if "id" in condition and isinstance(condition["id"], str): + chunkId = condition["id"] + try: + table.update( + primary_key={"id": chunkId}, + update_fields=doc, + ) + except Exception as e: + logging.warning(f"BaiduVDB update row from table by primary_key: id: {chunkId} failed, error: {str(e)}") + return False + return True + condition["kb_id"] = [knowledgebaseId] + filter = self._condition_to_filter(condition=condition) + logger.debug(f"BaiduVDB update row from table {table_name} start") + marker = None + projection = ["id"] + all_chunkIds = [] + # get ids by filter + try: + while True: + res = table.select(marker=marker, projections=projection, filter=filter, limit=50) + assert isinstance(res.rows, list) + for row in res.rows: + all_chunkIds.append(row['id']) + if res.is_truncated is False: + break + else: + marker = res.next_marker + except Exception as e: + logger.warning(f"BaiduVDB: Fail to update row from table {table_name}, filter: {filter}, due to get ids by select failed: error: {str(e)}") + return False + # update row by primary_key + try: + for chunkId in all_chunkIds: + table.update( + primary_key={"id": chunkId}, + update_fields=doc, + ) + except Exception as e: + logging.warning(f"BaiduVDB update row from table by primiary_ley got by filter: {filter} failed, error: {str(e)}") + return False + logger.debug(f"BaiduVDB update row from table {table_name}, filter: {filter}") + return True + + def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int: + if not self.indexExist(indexName, knowledgebaseId): + return 0 + condition["kb_id"] = [knowledgebaseId] + filter = self._condition_to_filter(condition=condition) + db = self.client.database(self.db_name) + table_name = indexName + table = db.table(table_name=table_name) + try: + table.delete(filter=filter) + except Exception as e: + logger.warning(f"BaiduVDB delete row from table {table_name} failed, filter: {filter}, error: {str(e)}") + return 0 + logger.debug(f"BaiduVDB delete row from table {table_name}, filter: {filter}") + return 0 + + def 
_condition_exisist(self, field: str) ->str: + assert field in self.mapping + field_attribute = self.mapping[field] + assert field_attribute["fieldType"] != FieldType.TEXT + assert isinstance(field_attribute, dict) + field_default = field_attribute["default"] + if field_attribute["fieldType"] == FieldType.ARRAY.value: + assert isinstance(field_default, list) + field_default_first = field_default[0] + if isinstance(field_default_first, str): + return f"{field}[0] != '{field_default_first}'" + else: + return f"{field}[0] != {field_default}" + else: + if isinstance(field_default, str): + return f"{field} != '{field_default}'" + else: + return f"{field} != {field_default}" + + def _condition_to_filter(self, condition: dict) -> str: + if "id" in condition: + id_value = condition["id"] + if isinstance(id_value, str): + return f"id == '{id_value}'" + elif isinstance(id_value, list): + inCond = [f'{item}' for item in id_value] + strInCond = ", ".join(inCond) + return f"id IN ({strInCond})" + cond = list() + for k, v in condition.items(): + if not isinstance(k, str) or not v: + continue + if k == "must_not": + if isinstance(v, dict): + for kk, vv in v.items(): + if kk == "exists": + cond.append("NOT (%s)" % self._condition_exisist(vv)) + continue + elif k == "exists": + cond.append(self._condition_exisist(v)) + continue + + assert k in self.mapping + field_attribute = self.mapping[k] + assert field_attribute["fieldType"] != FieldType.TEXT + assert isinstance(field_attribute, dict) + if isinstance(v, list): + inCond = list() + for item in v: + if isinstance(item, str): + inCond.append(f"'{item}'") + else: + inCond.append(str(item)) + if inCond: + strInCond = ", ".join(inCond) + if field_attribute["fieldType"] == FieldType.ARRAY.value: + cond.append(f"array_contains_any({k}, [{strInCond}])") + else: + cond.append( f"{k} IN ({strInCond})") + elif isinstance(v, str): + if field_attribute["fieldType"] == FieldType.ARRAY.value: + cond.append(f"array_contains({k}, '{v}')") + else: + cond.append(f"{k} == '{v}'") + else: + if field_attribute["fieldType"] == FieldType.ARRAY.value: + cond.append(f"array_contains({k}, {str(v)})") + else: + cond.append(f"{k} == {str(v)}") + + if len(cond) == 0: + return "" + return " AND ".join(cond) + + + """ + Helper functions for search result + """ + + def getTotal(self, rows): + return len(rows) + + def getChunkIds(self, rows): + return [row["id"] for row in rows] + + def getFields(self, rows, fields: list[str]) -> dict[str, dict]: + res_fields = {} + for row in rows: + m = {n: row.get(n) for n in fields if row.get(n) is not None} + if "position_int" in m and isinstance(m["position_int"], list): + m["position_int"] = [m["position_int"][i: i+5] for i in range(0, len(m["position_int"]), 5)] + if m: + res_fields[row["id"]] = m + return res_fields + + def getHighlight(self, rows, keywords: list[str], fieldnm: str): + ans = {} + for row in rows: + txt = row[fieldnm] + txt = re.sub(r"[\r\n]", " ", txt, flags=re.IGNORECASE | re.MULTILINE) + txts = [] + for t in re.split(r"[.?!;\n]", txt): + for w in keywords: + t = re.sub( + r"(^|[ .?/'\"\(\)!,:;-])(%s)([ .?/'\"\(\)!,:;-])" + % re.escape(w), + r"\1\2\3", + t, + flags=re.IGNORECASE | re.MULTILINE, + ) + if not re.search( + r"[^<>]+", t, flags=re.IGNORECASE | re.MULTILINE + ): + continue + txts.append(t) + ans[row["id"]] = "...".join(txts) + return ans + + def getAggregation(self, res, fieldnm: str): + return list() + + """ + SQL + """ + def sql(sql: str, fetch_size: int, format: str): + """ + Run the sql generated by text-to-sql + 
""" + raise NotImplementedError("BaiduVDB not support sql") + + + + + + diff --git a/uv.lock b/uv.lock index f9d903a47aa..47b45bb29fe 100644 --- a/uv.lock +++ b/uv.lock @@ -4863,6 +4863,20 @@ crypto = [ { name = "cryptography" }, ] +[[package]] +name = "pymochow" +version = "2.2.4" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "future" }, + { name = "orjson" }, + { name = "requests" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/80/38/0893f0940a9b1dfd741805401b986d896816e25d9e72f67a3d36b6da228d/pymochow-2.2.4.tar.gz", hash = "sha256:0c9dce6409012ab6e8b705e071e60fc2a8d065acf2b8a18de2ee3f0771ae27f6" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/67/99/6c1bab3bfe505946d4b63473010cf1318865beabde45aa685861a4fbddd6/pymochow-2.2.4-py3-none-any.whl", hash = "sha256:6ada2e4984510a409d9acdecaf4691086e7248067599eab1f84566bc474775a7" }, +] + [[package]] name = "pymysql" version = "1.1.1" @@ -5349,6 +5363,7 @@ dependencies = [ { name = "pyclipper" }, { name = "pycryptodomex" }, { name = "pyicu" }, + { name = "pymochow" }, { name = "pymysql" }, { name = "pyodbc" }, { name = "pypdf" }, @@ -5504,6 +5519,7 @@ requires-dist = [ { name = "pyclipper", specifier = "==1.3.0.post5" }, { name = "pycryptodomex", specifier = "==3.20.0" }, { name = "pyicu", specifier = ">=2.13.1,<3.0.0" }, + { name = "pymochow", specifier = "==2.2.4" }, { name = "pymysql", specifier = ">=1.1.1,<2.0.0" }, { name = "pyodbc", specifier = ">=5.2.0,<6.0.0" }, { name = "pypdf", specifier = "==6.0.0" }, diff --git a/web/src/components/llm-tools-select.tsx b/web/src/components/llm-tools-select.tsx index 241dfa374a4..56c3dc58e6c 100644 --- a/web/src/components/llm-tools-select.tsx +++ b/web/src/components/llm-tools-select.tsx @@ -9,7 +9,7 @@ interface IProps { } const LLMToolsSelect = ({ value, onChange, disabled }: IProps) => { - const { t } = useTranslate("llmTools"); + const { t } = useTranslate('llmTools'); const tools = useLlmToolsList(); function wrapTranslation(text: string): string { @@ -17,14 +17,14 @@ const LLMToolsSelect = ({ value, onChange, disabled }: IProps) => { return text; } - if (text.startsWith("$t:")) { + if (text.startsWith('$t:')) { return t(text.substring(3)); } return text; } - const toolOptions = tools.map(t => ({ + const toolOptions = tools.map((t) => ({ label: wrapTranslation(t.displayName), description: wrapTranslation(t.displayDescription), value: t.name, @@ -35,7 +35,7 @@ const LLMToolsSelect = ({ value, onChange, disabled }: IProps) => {