From 0977e0f721b9f24e02e345a08e39c27451d15cbd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?=
Date: Mon, 19 Jan 2026 19:00:44 +0800
Subject: [PATCH 01/46] feature update from custom version of doc processor

---
 lazyllm/tools/rag/doc_node.py                 |  16 +-
 lazyllm/tools/rag/parsing_service/base.py     |  18 +-
 lazyllm/tools/rag/parsing_service/impl.py     |  75 +++++++-
 lazyllm/tools/rag/parsing_service/server.py   |  68 +++++--
 lazyllm/tools/rag/parsing_service/worker.py   |  32 ++++
 lazyllm/tools/rag/store/document_store.py     |   5 +-
 .../tools/rag/store/hybrid/sensecore_store.py | 177 ++++++++++++------
 lazyllm/tools/rag/store/store_base.py         |   1 +
 lazyllm/tools/rag/store/utils.py              |  54 +++++-
 9 files changed, 355 insertions(+), 91 deletions(-)

diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py
index 8b50fdd8f..4a5c970d8 100644
--- a/lazyllm/tools/rag/doc_node.py
+++ b/lazyllm/tools/rag/doc_node.py
@@ -54,6 +54,7 @@ def __init__(self, uid: Optional[str] = None, content: Optional[Union[str, List[
         self.relevance_score = None
         self.similarity_score = None
         self._content_hash: Optional[str] = None
+        self._copy_source: Optional[dict] = None
 
     @property
     def uid(self) -> str:
@@ -280,13 +281,24 @@ def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str:
     def to_dict(self) -> Dict:
         return dict(content=self._content, embedding=self.embedding, metadata=self.metadata)
 
-    def with_score(self, score):
+    def copy(self, global_metadata: dict=None, metadata: dict=None) -> 'DocNode':
         node = copy.copy(self)
+        node._copy_source = {'uid': self.uid, RAG_KB_ID: self.global_metadata.get(RAG_KB_ID),
+                             RAG_DOC_ID: self.global_metadata.get(RAG_DOC_ID)}
+        node._uid = str(uuid.uuid4())
+        if metadata:
+            node.metadata = {**node.metadata, **metadata}
+        if global_metadata:
+            node.global_metadata = {**node.global_metadata, **global_metadata}
+        return node
+
+    def with_score(self, score):
+        node = self.copy()
         node.relevance_score = score
         return node
 
     def with_sim_score(self, score):
-        node = copy.copy(self)
+        node = self.copy()
         node.similarity_score = score
         return node
 
diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py
index 8e11a0d0f..c130a6201 100644
--- a/lazyllm/tools/rag/parsing_service/base.py
+++ b/lazyllm/tools/rag/parsing_service/base.py
@@ -1,15 +1,25 @@
-from pydantic import BaseModel, Field
-from typing import Dict, List, Optional, Any
+from pydantic import BaseModel, Field, BeforeValidator
+from typing import Dict, List, Optional, Any, Annotated
 from enum import Enum
 from uuid import uuid4
 from datetime import datetime
 
+class TransferParams(BaseModel):
+    mode: Optional[str] = 'cp'  # cp or mv
+    target_algo_id: str
+    target_doc_id: str
+    target_kb_id: str
+
+
+EmptyTransfer = Annotated[TransferParams | None, BeforeValidator(lambda v: None if v == {} else v)]
 
 class FileInfo(BaseModel):
     file_path: Optional[str] = None
     doc_id: Optional[str] = None
-    metadata: Optional[Dict[str, Any]] = {}
+    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
     reparse_group: Optional[str] = None
+    transformed_file_path: Optional[str] = None
+    transfer_params: EmptyTransfer = None
 
 
 class DBInfo(BaseModel):
@@ -72,6 +82,7 @@ class TaskType(str, Enum):
     DOC_DELETE = 'DOC_DELETE'
     DOC_UPDATE_META = 'DOC_UPDATE_META'
     DOC_REPARSE = 'DOC_REPARSE'
+    DOC_TRANSFER = 'DOC_TRANSFER'
 
 
 def _get_task_type_weight(task_type: str) -> int:
@@ -81,6 +92,7 @@ def _get_task_type_weight(task_type: str) -> int:
         TaskType.DOC_UPDATE_META.value: 30,
TaskType.DOC_ADD.value: 100, TaskType.DOC_REPARSE.value: 100, + TaskType.DOC_TRANSFER.value: 100, } return weight_map.get(task_type, 100) diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index b19651b87..deccc89a6 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -41,9 +41,12 @@ def reader(self) -> DirectoryReader: return self._reader def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # noqa: C901 - metadatas: Optional[List[Dict[str, Any]]] = None, kb_id: Optional[str] = None): + metadatas: Optional[List[Dict[str, Any]]] = None, kb_id: Optional[str] = None, + transfer_mode: Optional[str] = None, target_kb_id: Optional[str] = None, + target_doc_ids: Optional[List[str]] = None): try: if not input_files: return + add_start = time.time() if not ids: ids = [gen_docid(path) for path in input_files] if metadatas is None: metadatas = [{} for _ in input_files] @@ -52,24 +55,76 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no metadata.setdefault(RAG_DOC_PATH, path) metadata.setdefault(RAG_KB_ID, kb_id or DEFAULT_KB_ID) kb_id = metadatas[0].get(RAG_KB_ID, DEFAULT_KB_ID) if kb_id is None else kb_id - root_nodes = self._reader.load_data(input_files, metadatas, split_nodes_by_type=True) + + load_start = time.time() + if transfer_mode is None: + root_nodes = self._reader.load_data(input_files, metadatas, split_nodes_by_type=True) + else: + if transfer_mode not in ('cp', 'mv'): + raise ValueError(f'Invalid transfer mode: {transfer_mode}') + if len(ids) != len(target_doc_ids): + raise ValueError(f'The length of doc_ids and target_doc_ids must be the same. ' + f'doc_ids:{ids}, target_doc_ids:{target_doc_ids}') + doc_id_map = {ids[i]: (target_doc_ids[i], metadatas[i]) for i in range(len(ids))} + + root_nodes: List[DocNode] = self._store.get_nodes(doc_ids=ids, group=LAZY_ROOT_NAME, kb_id=kb_id) + root_nodes = [ + n.copy( + global_metadata={ + RAG_KB_ID: target_kb_id, RAG_DOC_ID: doc_id_map[n.global_metadata[RAG_DOC_ID]][0] + }, + metadata=doc_id_map[n.global_metadata[RAG_DOC_ID]][1] + ) for n in root_nodes + ] + load_time = time.time() - load_start + schema_futures = [] schema_errors: List[Exception] = [] - if self._schema_extractor: + # run schema extraction in parallel + if self._schema_extractor and not transfer_mode: doc_to_root_nodes = defaultdict(list) for n in root_nodes[LAZY_ROOT_NAME]: doc_to_root_nodes[n.global_metadata.get(RAG_DOC_ID)].append(n) - + if doc_to_root_nodes: for nodes in doc_to_root_nodes.values(): schema_futures.append( self._thread_pool.submit(self._schema_extractor, nodes, algo_id=self._algo_id) ) - for k, v in root_nodes.items(): - if not v: continue - self._store.update_nodes(self._set_nodes_number(v)) - self._create_nodes_recursive(v, k) + if transfer_mode is None: + for k, v in root_nodes.items(): + if not v: continue + self._store.update_nodes(self._set_nodes_number(v)) + self._create_nodes_recursive(v, k) + else: + self._store.update_nodes(root_nodes, copy=True) + root_uid_map = {n._copy_source.get('uid'): n.uid for n in root_nodes} + def _copy_segments_recursive(p_uid_map: dict, p_name: str, **kwargs): + for group_name in self._store.activated_groups(): + group = self._node_groups.get(group_name) + if group is None: + raise ValueError(f'Node group {group_name} does not exist. 
Please check the group name '
+                                         'or add a new one through `create_node_group`.')
+                    if group['parent'] == p_name:
+                        nodes = self._store.get_nodes(doc_ids=ids, group=group_name, kb_id=kb_id)
+                        nodes = [
+                            n.copy(
+                                global_metadata={
+                                    RAG_KB_ID: target_kb_id, RAG_DOC_ID: doc_id_map[
+                                        n.global_metadata[RAG_DOC_ID]
+                                    ][0]
+                                },
+                                metadata=doc_id_map[n.global_metadata[RAG_DOC_ID]][1]
+                            ) for n in nodes
+                        ]
+                        uid_map = {}
+                        for n in nodes:
+                            uid_map[n._copy_source.get('uid')] = n.uid
+                            n.parent = p_uid_map.get(n.parent, None) if n.parent else None
+                        self._store.update_nodes(nodes, copy=True)
+                        if nodes: _copy_segments_recursive(uid_map, group_name)
+                _copy_segments_recursive(p_uid_map=root_uid_map, p_name=LAZY_ROOT_NAME)
 
             for future in schema_futures:
                 try:
@@ -79,7 +134,9 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None,  # no
                     schema_errors.append(exc)
             if schema_errors:
                 raise schema_errors[0]
-            LOG.info('Add documents done!')
+            add_time = time.time() - add_start
+            LOG.info(f'[_Processor - add_doc] Add documents done! files:{input_files}, '
+                     f'Total Time: {add_time}s, Load Time: {load_time}s')
         except Exception as e:
             LOG.error(f'Add documents failed: {e}, {traceback.format_exc()}')
             raise e
diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py
index a0de7b1e5..ea8a46da6 100644
--- a/lazyllm/tools/rag/parsing_service/server.py
+++ b/lazyllm/tools/rag/parsing_service/server.py
@@ -14,7 +14,8 @@ from .base import (
     ALGORITHM_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, FINISHED_TASK_QUEUE_TABLE_INFO,
-    TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, _calculate_task_score
+    TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, TransferParams,
+    _calculate_task_score
 )
 from .worker import DocumentProcessorWorker as Worker
 from .queue import _SQLBasedQueue as Queue
@@ -260,8 +261,7 @@ def add_doc(self, request: AddDocRequest):
         self._lazy_init()
         if self._shutdown:
             raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...')
-        payload = request.model_dump()
-        LOG.info(f'[DocumentProcessor] Received add doc request: {payload}')
+        LOG.info(f'[DocumentProcessor] Received add doc request (raw): {request.model_dump()}')
         task_id = request.task_id
         algo_id = request.algo_id
         file_infos = request.file_infos
@@ -273,24 +273,65 @@ def add_doc(self, request: AddDocRequest):
         # NOTE: No idempotency key check, should be handled by the caller!
         new_file_ids = []
         reparse_file_ids = []
+
+        transfer_mode = None
+        target_algo_id = None
+        target_kb_id = None
+
         for file_info in file_infos:
-            if self._path_prefix:
-                file_info.file_path = create_file_path(path=file_info.file_path, prefix=self._path_prefix)
+            parse_file_path = file_info.transformed_file_path if \
+                file_info.transformed_file_path else file_info.file_path
+            file_info.file_path = create_file_path(parse_file_path, prefix=self._path_prefix)
             if file_info.reparse_group is not None:
                 reparse_file_ids.append(file_info.doc_id)
             else:
                 new_file_ids.append(file_info.doc_id)
+            if file_info.transfer_params:
+                if target_algo_id is not None and target_algo_id != file_info.transfer_params.target_algo_id:
+                    raise fastapi.HTTPException(
+                        status_code=400,
+                        detail='transfer_params.target_algo_id must be the same for all files'
+                    )
+                if target_kb_id is not None and target_kb_id != file_info.transfer_params.target_kb_id:
+                    raise fastapi.HTTPException(
+                        status_code=400,
+                        detail='transfer_params.target_kb_id must be the same for all files'
+                    )
+                if transfer_mode is not None and transfer_mode != file_info.transfer_params.mode:
+                    raise fastapi.HTTPException(
+                        status_code=400,
+                        detail='transfer_params.mode must be the same for all files'
+                    )
+                # NOTE: currently we only support file transfer in the same algorithm
+                if file_info.transfer_params.target_algo_id != algo_id:
+                    raise fastapi.HTTPException(
+                        status_code=400,
+                        detail='transfer_params.target_algo_id must match the request algo_id '
+                               '(cross-algorithm transfer is not supported)'
+                    )
+                target_algo_id = file_info.transfer_params.target_algo_id
+                target_kb_id = file_info.transfer_params.target_kb_id
+                transfer_mode = file_info.transfer_params.mode
+                if transfer_mode not in ['cp', 'mv']:
+                    raise fastapi.HTTPException(
+                        status_code=400,
+                        detail='transfer_params.mode must be one of [cp, mv]'
+                    )
+
         if new_file_ids and reparse_file_ids:
             raise fastapi.HTTPException(
                 status_code=400, detail='new_file_ids and reparse_file_ids cannot be specified at the same time'
             )
-        if new_file_ids:
+        if transfer_mode:
+            task_type = TaskType.DOC_TRANSFER.value
+        elif new_file_ids:
             task_type = TaskType.DOC_ADD.value
         elif reparse_file_ids:
             task_type = TaskType.DOC_REPARSE.value
         else:
             raise fastapi.HTTPException(status_code=400, detail='no input files or reparse group specified')
+        payload = request.model_dump()
+        LOG.info(f'[DocumentProcessor] Received add doc request: {payload}')
         payload_json = json.dumps(payload, ensure_ascii=False)
 
         try:
@@ -321,8 +362,7 @@ def update_meta(self, request: UpdateMetaRequest):
         self._lazy_init()
         if self._shutdown:
             raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...')
-        payload = request.model_dump()
-        LOG.info(f'update doc meta for {payload}')
+        LOG.info(f'[DocumentProcessor] Received update meta request (raw): {request.model_dump()}')
         task_id = request.task_id
         algo_id = request.algo_id
         file_infos = request.file_infos
@@ -332,6 +372,8 @@ def update_meta(self, request: UpdateMetaRequest):
         algorithm = self._get_algo(algo_id)
         if algorithm is None:
             raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}')
+        payload = request.model_dump()
+        LOG.info(f'[DocumentProcessor] Received update meta request: {payload}')
         payload_json = json.dumps(payload, ensure_ascii=False)
         try:
             task_type = TaskType.DOC_UPDATE_META.value
@@ -362,9 +404,7 @@ def delete_doc(self, request: DeleteDocRequest):
         self._lazy_init()
         if self._shutdown:
             raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...')
-        payload = request.model_dump()
-
LOG.info(f'[DocumentProcessor] Received delete doc request: {payload}') - + LOG.info(f'[DocumentProcessor] Received delete doc request (raw): {request.model_dump()}') task_id = request.task_id algo_id = request.algo_id doc_ids = request.doc_ids @@ -373,7 +413,8 @@ def delete_doc(self, request: DeleteDocRequest): algorithm = self._get_algo(algo_id) if algorithm is None: raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') - + payload = request.model_dump() + LOG.info(f'[DocumentProcessor] Received delete doc request: {payload}') payload_json = json.dumps(payload, ensure_ascii=False) try: task_type = TaskType.DOC_DELETE.value @@ -404,8 +445,7 @@ def cancel(self, request: CancelTaskRequest): self._lazy_init() if self._shutdown: raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...') - payload = request.model_dump() - LOG.info(f'[DocumentProcessor] Received cancel task request: {payload}') + LOG.info(f'[DocumentProcessor] Received cancel task request (raw): {request.model_dump()}') task_id = request.task_id try: # NOTE: only the task in waiting state can be canceled diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index ac768325e..43f55f4b7 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -35,6 +35,8 @@ def _lazy_init(self): table_name=WAITING_TASK_QUEUE_TABLE_INFO['name'], columns=WAITING_TASK_QUEUE_TABLE_INFO['columns'], db_config=self._db_config, + order_by='task_score', + order_desc=False, ) self._finished_task_queue = Queue( table_name=FINISHED_TASK_QUEUE_TABLE_INFO['name'], @@ -147,6 +149,34 @@ def _exec_reparse_task( except Exception as e: LOG.error(f'[DocumentProcessorWorker._Impl] Task-{task_id}: execute reparse task failed, error: {e}') raise e + + def _exec_transfer_task(self, processor: _Processor, task_id: str, payload: dict): + try: + file_infos = payload.get('file_infos') + kb_id = payload.get('kb_id', None) + input_files = [] + ids = [] + metadatas = [] + + transfer_mode = None + target_kb_id = None + target_doc_ids = [] + + for file_info in file_infos: + input_files.append(file_info.get('file_path')) + ids.append(file_info.get('doc_id')) + metadatas.append(file_info.get('metadata')) + if transfer_mode is None: + transfer_mode = file_info.get('transfer_params', {}).get('mode') + if target_kb_id is None: + target_kb_id = file_info.get('transfer_params', {}).get('target_kb_id') + target_doc_ids.append(file_info.get('transfer_params', {}).get('target_doc_id')) + processor.add_doc(input_files=input_files, ids=ids, metadatas=metadatas, kb_id=kb_id, + transfer_mode=transfer_mode, target_kb_id=target_kb_id, target_doc_ids=target_doc_ids) + + except Exception as e: + LOG.error(f'[DocumentProcessorWorker._Impl] Task-{task_id}: execute transfer task failed, error: {e}') + raise e def _exec_delete_task(self, processor: _Processor, task_id: str, payload: dict): try: @@ -220,6 +250,8 @@ def _worker_impl(self): self._exec_delete_task(processor, task_id, payload) elif task_type == TaskType.DOC_UPDATE_META.value: self._exec_update_meta_task(processor, task_id, payload) + elif task_type == TaskType.DOC_TRANSFER.value: + self._exec_transfer_task(processor, task_id, payload) else: raise ValueError(f'[DocumentProcessorWorker._Impl] Unknown task type: {task_type}') diff --git a/lazyllm/tools/rag/store/document_store.py b/lazyllm/tools/rag/store/document_store.py index feea84a45..406223d6d 100644 --- 
a/lazyllm/tools/rag/store/document_store.py +++ b/lazyllm/tools/rag/store/document_store.py @@ -153,14 +153,14 @@ def is_group_active(self, group: str) -> bool: def is_group_empty(self, group: str) -> bool: return not self.impl.get(self._gen_collection_name(group), {}, limit=10) - def update_nodes(self, nodes: List[DocNode]): # noqa: C901 + def update_nodes(self, nodes: List[DocNode], copy: bool = False): # noqa: C901 if not nodes: return try: if self._embed and self.impl.capability == StoreCapability.SEGMENT: LOG.warning(f'[_DocumentStore - {self._algo_name}] Embed is provided' f' but store {self.impl} does not support embedding') - if self.impl.need_embedding: + if self.impl.need_embedding and not copy: parallel_do_embedding(self._embed, [], nodes, self._group_embed_keys) group_segments = defaultdict(list) for node in nodes: @@ -385,6 +385,7 @@ def _serialize_node(self, node: DocNode) -> dict: kb_id=node.global_metadata.get(RAG_KB_ID, DEFAULT_KB_ID), excluded_embed_metadata_keys=node.excluded_embed_metadata_keys, excluded_llm_metadata_keys=node.excluded_llm_metadata_keys, + copy_source=node._copy_source or {}, ) if node.parent: segment.parent = node.parent._uid if isinstance(node.parent, DocNode) else node.parent diff --git a/lazyllm/tools/rag/store/hybrid/sensecore_store.py b/lazyllm/tools/rag/store/hybrid/sensecore_store.py index 2c3b9a89f..355ffb9c3 100644 --- a/lazyllm/tools/rag/store/hybrid/sensecore_store.py +++ b/lazyllm/tools/rag/store/hybrid/sensecore_store.py @@ -2,7 +2,6 @@ import json import uuid import time -import random import requests from pydantic import BaseModel, Field @@ -11,15 +10,17 @@ from ..store_base import (LazyLLMStoreBase, StoreCapability, LAZY_ROOT_NAME, IMAGE_PATTERN, INSERT_BATCH_SIZE, DEFAULT_KB_ID, SegmentType) -from ..utils import upload_data_to_s3, download_data_from_s3, fibonacci_backoff, create_file_path +from ..utils import upload_data_to_s3, download_data_from_s3, fibonacci_backoff, create_file_path, presign_obj_from_s3 from ...data_type import DataType from ...global_metadata import GlobalMetadataDesc, RAG_DOC_ID, RAG_KB_ID -from lazyllm import warp, pipeline, LOG, config +from lazyllm import warp, pipeline, LOG, config, package from lazyllm.common import override from lazyllm.thirdparty import boto3 +PRESIGN_EXPIRE_TIME = 3600 * 24 * 7 + class Segment(BaseModel): segment_id: str @@ -37,6 +38,7 @@ class Segment(BaseModel): answer: Optional[str] = '' image_keys: Optional[List[str]] = Field(default_factory=list) number: Optional[int] = 0 + copy_source: Optional[Dict[str, str]] = Field(default_factory=dict) class SenseCoreStore(LazyLLMStoreBase): @@ -48,6 +50,7 @@ def __init__(self, uri: str = '', **kwargs): self._uri = uri self._s3_config = kwargs.get('s3_config') self._image_url_config = kwargs.get('image_url_config') + self._uploaded_image_keys = set() @property def dir(self): @@ -68,38 +71,75 @@ def _check_s3(self): LOG.info(f'[SenseCore Store - check_s3] uploaded warmup.txt to {self._s3_config["bucket_name"]}') return + def _upload_image_if_needed(self, file_path: str, obj_key: str): + if obj_key in self._uploaded_image_keys: + return + + with open(file_path, 'rb') as f: + upload_data_to_s3( + f.read(), + bucket_name=self._s3_config['bucket_name'], + object_key=obj_key, + aws_access_key_id=self._s3_config['access_key'], + aws_secret_access_key=self._s3_config['secret_access_key'], + use_minio=self._s3_config['use_minio'], + endpoint_url=self._s3_config['endpoint_url'] + ) + + self._uploaded_image_keys.add(obj_key) + def _serialize_data(self, 
data: dict) -> Dict: # noqa: C901 data = dict(data) content = json.dumps(data.get('content', ''), ensure_ascii=False) matches = IMAGE_PATTERN.findall(content) + doc_id = data.get('doc_id', '') + kb_id = data.get(RAG_KB_ID, DEFAULT_KB_ID) for _, image_path in matches: if image_path.startswith('lazyllm'): continue image_file_name = os.path.basename(image_path) - obj_key = f'lazyllm/images/{image_file_name}' + obj_key = f'lazyllm/images/{kb_id}/{doc_id}/{image_file_name}' try: prefix = config['image_path_prefix'] except Exception: prefix = os.getenv('RAG_IMAGE_PATH_PREFIX', '') file_path = create_file_path(path=image_path, prefix=prefix) try: - with open(file_path, 'rb') as f: - upload_data_to_s3(f.read(), bucket_name=self._s3_config['bucket_name'], object_key=obj_key, - aws_access_key_id=self._s3_config['access_key'], - aws_secret_access_key=self._s3_config['secret_access_key'], - use_minio=self._s3_config['use_minio'], - endpoint_url=self._s3_config['endpoint_url']) - content = content.replace(image_path, obj_key) + self._upload_image_if_needed(file_path, obj_key) + content = content.replace(image_path, obj_key) except FileNotFoundError: LOG.error(f'Cannot find image path: {image_path} (local path {file_path}), skip...') except Exception as e: LOG.error(f'Error when uploading `{image_path}` {e!r}') - finally: - time.sleep(0.1 + random.random() * 0.4) data['content'] = json.loads(content) + # special requirement: item called `table_image_map` in metadata, need to upload to s3 + if data.get('meta', {}).get('table_image_map', {}): + for k, md_info in data['meta']['table_image_map'].items(): + matches = IMAGE_PATTERN.findall(md_info) + if not matches: + continue + image_path = matches[0][1] + if image_path.startswith('lazyllm'): + continue + image_name = os.path.basename(image_path) + obj_key = f'lazyllm/images/{kb_id}/{doc_id}/{image_name}' + try: + prefix = config['image_path_prefix'] + except Exception: + prefix = os.getenv('RAG_IMAGE_PATH_PREFIX', '') + file_path = create_file_path(path=image_path, prefix=prefix) + try: + self._upload_image_if_needed(file_path, obj_key) + md_info = md_info.replace(image_path, obj_key) + data['meta']['table_image_map'][k] = md_info + except FileNotFoundError: + LOG.error(f'Cannot find image: {image_path} (local path {file_path}, obj key {obj_key}), skip...') + except Exception as e: + LOG.error(f'Error when uploading `{image_path}` (local path {file_path}, obj key {obj_key}) {e!r}') + if data.get('group') == LAZY_ROOT_NAME: - obj_key = f'lazyllm/lazyllm_root/{data.get("uid")}.json' + obj_key = f"lazyllm/lazyllm_root/{kb_id}/{doc_id}/{data.get('uid')}.json" upload_data_to_s3(content.encode('utf-8'), bucket_name=self._s3_config['bucket_name'], object_key=obj_key, aws_access_key_id=self._s3_config['access_key'], aws_secret_access_key=self._s3_config['secret_access_key'], @@ -114,6 +154,12 @@ def _serialize_data(self, data: dict) -> Dict: # noqa: C901 parent=data.get('parent', ''), global_meta=json.dumps(data.get('global_meta', {}), ensure_ascii=False), answer=data.get('answer', ''), number=data.get('number', 0)) + if len(data.get('copy_source', {})): + segment.copy_source = { + 'dataset_id': data.get('copy_source', {}).get(RAG_KB_ID, DEFAULT_KB_ID), + 'document_id': data.get('copy_source', {}).get(RAG_DOC_ID, ''), + 'segment_id': data.get('copy_source', {}).get('uid', '') + } # image extract if isinstance(segment.content, str): target = segment.content @@ -126,14 +172,9 @@ def _serialize_data(self, data: dict) -> Dict: # noqa: C901 if data.get('type') == 
SegmentType.IMAGE.value and data.get('image_keys'): image_path = data.get('image_keys', [])[0] image_file_name = os.path.basename(image_path) - obj_key = f'lazyllm/images/{image_file_name}' + obj_key = f'lazyllm/images/{kb_id}/{doc_id}/{image_file_name}' try: - with open(image_path, 'rb') as f: - upload_data_to_s3(f.read(), bucket_name=self._s3_config['bucket_name'], object_key=obj_key, - aws_access_key_id=self._s3_config['access_key'], - aws_secret_access_key=self._s3_config['secret_access_key'], - use_minio=self._s3_config['use_minio'], - endpoint_url=self._s3_config['endpoint_url']) + self._upload_image_if_needed(image_path, obj_key) segment.image_keys = [obj_key] except FileNotFoundError: LOG.error(f'Cannot find image path: {image_path} (local path {image_path}), skip...') @@ -146,20 +187,15 @@ def _serialize_data(self, data: dict) -> Dict: # noqa: C901 if image_path.startswith('lazyllm'): continue image_file_name = os.path.basename(image_path) - obj_key = f'lazyllm/images/{image_file_name}' + obj_key = f'lazyllm/images/{kb_id}/{doc_id}/{image_file_name}' try: prefix = config['image_path_prefix'] except Exception: prefix = os.getenv('RAG_IMAGE_PATH_PREFIX', '') file_path = create_file_path(path=image_path, prefix=prefix) try: - with open(file_path, 'rb') as f: - upload_data_to_s3(f.read(), bucket_name=self._s3_config['bucket_name'], object_key=obj_key, - aws_access_key_id=self._s3_config['access_key'], - aws_secret_access_key=self._s3_config['secret_access_key'], - use_minio=self._s3_config['use_minio'], - endpoint_url=self._s3_config['endpoint_url']) - answer = answer.replace(image_path, obj_key) + self._upload_image_if_needed(file_path, obj_key) + answer = answer.replace(image_path, obj_key) except FileNotFoundError: LOG.error(f'Cannot find image path: {image_path} (local path {file_path}), skip...') except Exception as e: @@ -171,12 +207,12 @@ def _serialize_data(self, data: dict) -> Dict: # noqa: C901 segment.answer = data['answer'] return segment.model_dump() - def _deserialize_data(self, segment: Dict) -> Dict: + def _deserialize_data(self, segment: Dict, display: bool = False) -> Dict: data = { 'uid': segment.get('segment_id', ''), 'doc_id': segment.get('document_id', ''), 'group': segment.get('group', ''), - 'content': segment.get('content', ''), + 'content': segment.get('content', '') if not display else segment.get('display_content'), 'meta': json.loads(segment.get('meta', '{}')), 'global_meta': json.loads(segment.get('global_meta', '{}')), 'number': segment.get('number', 0), @@ -199,6 +235,27 @@ def _deserialize_data(self, segment: Dict) -> Dict: use_minio=self._s3_config['use_minio'], endpoint_url=self._s3_config['endpoint_url'], encoding='utf-8') data['content'] = json.loads(content) + + if display and data.get('meta', {}).get('table_image_map', {}): + for k, v in data['meta']['table_image_map'].items(): + matches = IMAGE_PATTERN.findall(v) + if not matches: + continue + image_path = matches[0][1] + if not image_path.startswith('lazyllm'): + LOG.warning(f'[SenseCore Store]: table_image value must start with lazyllm, value: {image_path}') + continue + url = presign_obj_from_s3( + bucket_name=self._s3_config['bucket_name'], + object_key=image_path, + aws_access_key_id=self._s3_config['access_key'], + aws_secret_access_key=self._s3_config['secret_access_key'], + endpoint_url=self._s3_config.get('external_endpoint_url', self._s3_config['endpoint_url']), + region_name=self._s3_config.get('region_name', 'us-east-1'), + client_method='get_object', + 
expires_in=PRESIGN_EXPIRE_TIME, + ) + data['meta']['table_image_map'][k] = v.replace(image_path, url) return data def _create_filters_str(self, filters: Dict[str, Union[str, int, List, Set]]) -> str: @@ -221,7 +278,7 @@ def _create_filters_str(self, filters: Dict[str, Union[str, int, List, Set]]) -> return ret_str[:-5] # truncate the last ' and ' return ret_str - def _upload_data_and_insert(self, data: List[dict]) -> str: + def _upload_data_and_insert(self, data: List[dict], job_type: str = 'insert') -> str: try: job_id = str(uuid.uuid4()) groups = set() @@ -245,11 +302,12 @@ def _upload_data_and_insert(self, data: List[dict]) -> str: url = urljoin(self._uri, 'v1/writerSegmentJob:submit') params = {'writer_segment_job_id': job_id} headers = {'Accept': 'application/json', 'Content-Type': 'application/json'} - payload = {'dataset_id': dataset_id or self._kb_id, 'file_key': obj_key, 'groups': groups} + payload = {'dataset_id': dataset_id or self._kb_id, 'file_key': obj_key, + 'groups': groups, 'job_type': job_type} response = requests.post(url, params=params, headers=headers, json=payload) response.raise_for_status() - LOG.info(f'SenseCore Store: insert task {job_id} submitted') + LOG.info(f'SenseCore Store: insert task {job_id} submitted, payload:{payload}') except Exception as e: LOG.error(f'SenseCore Store: insert task {job_id} failed: {e}') raise e @@ -258,18 +316,25 @@ def _upload_data_and_insert(self, data: List[dict]) -> str: def _check_insert_job_status(self, job_id: str) -> None: url = urljoin(self._uri, f'v1/writerSegmentJobs/{job_id}') headers = {'Accept': 'application/json'} - for wait_time in fibonacci_backoff(max_retries=15): + check_start_time = time.time() + flag = False + for wait_time in fibonacci_backoff(max_retries=16): response = requests.get(url, headers=headers) response.raise_for_status() status = response.json()['state'] if status == 2: - LOG.info(f'SenseCore Store: insert task {job_id} finished') - return + flag = True + break elif status == 3: - raise Exception(f'Insert task {job_id} failed') + break else: time.sleep(wait_time) - raise Exception(f'Insert task {job_id} failed after seconds') + check_end_time = time.time() + if not flag: + LOG.error(f'SenseCore Store: insert task {job_id} failed after {check_end_time - check_start_time}s') + raise Exception(f'Insert task {job_id} failed after {check_end_time - check_start_time}s') + LOG.info(f'SenseCore Store: insert task {job_id} finished after {check_end_time - check_start_time}s') + return def _get_group_name(self, collection_name: str) -> str: return collection_name.split('_')[-1] if 'lazyllm_root' not in collection_name else 'lazyllm_root' @@ -278,12 +343,19 @@ def _get_group_name(self, collection_name: str) -> str: def upsert(self, collection_name: str, data: List[dict]) -> bool: if not data: return True try: + upsert_start_time = time.time() + job_type = 'insert' if not len(data[0].get('copy_source', {})) else 'copy' with pipeline() as insert_ppl: insert_ppl.get_ids = warp(self._upload_data_and_insert).aslist insert_ppl.check_status = warp(self._check_insert_job_status) - batched_data = [data[i:i + INSERT_BATCH_SIZE] for i in range(0, len(data), INSERT_BATCH_SIZE)] + batched_data = [ + package(data[i:i + INSERT_BATCH_SIZE], job_type) for i in range(0, len(data), INSERT_BATCH_SIZE) + ] insert_ppl(batched_data) + upsert_end_time = time.time() + LOG.info(f'[SenseCore Store - upsert] Upsert done! 
collection_name:{collection_name}, ' + f'Time:{upsert_end_time - upsert_start_time}s') return True except Exception as e: LOG.error(f'[SenseCore Store - upsert] insert task failed: {e}') @@ -332,7 +404,7 @@ def get(self, collection_name: str, criteria: dict, **kwargs) -> List[dict]: # if uids: payload['segment_ids'] = uids else: - payload['page_size'] = 100 + payload['page_size'] = 1000 segments = [] while True: response = requests.post(url, headers=headers, json=payload) @@ -350,23 +422,11 @@ def get(self, collection_name: str, criteria: dict, **kwargs) -> List[dict]: # payload['page_token'] = next_page_token if doc_ids: segments = [segment for segment in segments if segment['document_id'] in doc_ids] - if kwargs.get('display'): - segments = self._apply_display(segments) - return [self._deserialize_data(s) for s in segments] + return [self._deserialize_data(s, display=kwargs.get('display', False)) for s in segments] except Exception as e: LOG.error(f'[SenseCore Store - get]:task failed: {e}') return [] - def _apply_display(self, segments: List[dict]) -> List[dict]: - out = [] - for s in segments: - if not s.get('is_active', True): - continue - if s.get('display_content'): - s['content'] = s['display_content'] - out.append(s) - return out - def _multi_modal_process(self, query: str, images: List[str]): urls = [] s3 = boto3.client('s3', aws_access_key_id=self._image_url_config['access_key'], @@ -417,14 +477,11 @@ def search(self, collection_name: str, query: Union[str, dict, List[float]], top payload = {'query': query, 'hybrid_search_datasets': hybrid_search_datasets, 'hybrid_search_type': 2, 'top_k': topk, 'filters': filter_str, 'group': self._get_group_name(collection_name), 'embedding_model': embed_key, 'images': images} - response = requests.post(url, headers=headers, json=payload) + response = requests.post(url, headers=headers, json=payload, timeout=60) response.raise_for_status() segments = response.json()['segments'] - segments = [s for s in segments if s['is_active']] - for s in segments: - if len(s.get('display_content', '')): - s['content'] = s['display_content'] - return [self._deserialize_data(s) for s in segments] + segments = [s for s in segments if s.get('is_active', True)] + return [self._deserialize_data(s, display=True) for s in segments] except Exception as e: LOG.error(f'SenseCore Store: query task failed: {e}') raise e diff --git a/lazyllm/tools/rag/store/store_base.py b/lazyllm/tools/rag/store/store_base.py index caefc7fb2..3897de564 100644 --- a/lazyllm/tools/rag/store/store_base.py +++ b/lazyllm/tools/rag/store/store_base.py @@ -60,6 +60,7 @@ class Segment(BaseModel): parent: Optional[str] = None # uid of parent node answer: Optional[str] = '' image_keys: Optional[List[str]] = Field(default_factory=list) + copy_source: Optional[Dict[str, str]] = Field(default_factory=dict) class StoreCapability(IntFlag): diff --git a/lazyllm/tools/rag/store/utils.py b/lazyllm/tools/rag/store/utils.py index 9a3bf48ba..c2666ddc3 100644 --- a/lazyllm/tools/rag/store/utils.py +++ b/lazyllm/tools/rag/store/utils.py @@ -9,7 +9,7 @@ from lazyllm import LOG from lazyllm.thirdparty import boto3 -from typing import Optional, Union +from typing import Optional, Union, Dict, Any from io import BytesIO INSERT_MAX_RETRIES = 10 @@ -184,6 +184,58 @@ def download_data_from_s3( except OSError: pass +def presign_obj_from_s3( + bucket_name: str, + object_key: str, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + endpoint_url: Optional[str] = None, + 
region_name: str = "us-east-1", + client_method: str = "get_object", + expires_in: int = 3600, + extra_params: Optional[Dict[str, Any]] = None, +) -> str: + + spec = importlib.util.find_spec("botocore.client") + if spec is None: + raise ImportError( + "Please install boto3 to use botocore module. " + "You can install it with `pip install boto3`" + ) + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + Config = m.Config + + s3_client = boto3.client( + "s3", + region_name=region_name, + endpoint_url=endpoint_url, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + config=Config( + signature_version="s3v4", + ), + ) + + params = { + "Bucket": bucket_name, + "Key": object_key, + } + + if extra_params: + params.update(extra_params) + + try: + url = s3_client.generate_presigned_url( + ClientMethod=client_method, + Params=params, + ExpiresIn=expires_in, + ) + return url + except Exception as e: + LOG.error(f"Generate presigned url failed: {e}") + raise e + def fibonacci_backoff(max_retries: int = INSERT_MAX_RETRIES): a, b = 1, 1 for _ in range(max_retries): From 35000406539304c441a2db5a0ad8617a46d54738 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 20 Jan 2026 19:58:05 +0800 Subject: [PATCH 02/46] enhance worker for parsing service --- lazyllm/tools/rag/doc_node.py | 2 +- lazyllm/tools/rag/parsing_service/base.py | 10 +- lazyllm/tools/rag/parsing_service/impl.py | 9 +- lazyllm/tools/rag/parsing_service/queue.py | 72 +++++ lazyllm/tools/rag/parsing_service/server.py | 95 +++++- lazyllm/tools/rag/parsing_service/worker.py | 320 +++++++++++++++++--- lazyllm/tools/rag/store/utils.py | 20 +- lazyllm/tools/sql/sql_manager.py | 3 +- 8 files changed, 462 insertions(+), 69 deletions(-) diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 4a5c970d8..3dc0a1f95 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -281,7 +281,7 @@ def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: def to_dict(self) -> Dict: return dict(content=self._content, embedding=self.embedding, metadata=self.metadata) - def copy(self, global_metadata: dict=None, metadata: dict=None) -> 'DocNode': + def copy(self, global_metadata: dict = None, metadata: dict = None) -> 'DocNode': node = copy.copy(self) node._copy_source = {'uid': self.uid, RAG_KB_ID: self.global_metadata.get(RAG_KB_ID), RAG_DOC_ID: self.global_metadata.get(RAG_DOC_ID)} diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index c130a6201..d7274d05e 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -120,8 +120,16 @@ def _calculate_task_score(task_type: str, user_priority: int) -> int: 'comment': 'Calculated task score (for sorting, lower score is higher priority)'}, {'name': 'message', 'data_type': 'string', 'nullable': False, 'comment': 'Task message (json string, serialized from request body)'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'default': TaskStatus.WAITING.value, + 'comment': 'Task status: WAITING, WORKING'}, + {'name': 'worker_id', 'data_type': 'string', 'nullable': True, + 'comment': 'Worker ID holding the lease'}, + {'name': 'lease_expires_at', 'data_type': 'datetime', 'nullable': True, + 'comment': 'Lease expiration time for in-progress task'}, {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, - 'comment': 'Creation time (auto-generated)', 
'default': datetime.now()}, + 'comment': 'Creation time (auto-generated)', 'default': datetime.now}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, + 'comment': 'Last update time (auto-generated)', 'default': datetime.now}, ] } diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index deccc89a6..396c46cac 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -64,7 +64,7 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no raise ValueError(f'Invalid transfer mode: {transfer_mode}') if len(ids) != len(target_doc_ids): raise ValueError(f'The length of doc_ids and target_doc_ids must be the same. ' - f'doc_ids:{ids}, target_doc_ids:{target_doc_ids}') + f'doc_ids:{ids}, target_doc_ids:{target_doc_ids}') doc_id_map = {ids[i]: (target_doc_ids[i], metadatas[i]) for i in range(len(ids))} root_nodes: List[DocNode] = self._store.get_nodes(doc_ids=ids, group=LAZY_ROOT_NAME, kb_id=kb_id) @@ -85,7 +85,7 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no doc_to_root_nodes = defaultdict(list) for n in root_nodes[LAZY_ROOT_NAME]: doc_to_root_nodes[n.global_metadata.get(RAG_DOC_ID)].append(n) - + if doc_to_root_nodes: for nodes in doc_to_root_nodes.values(): schema_futures.append( @@ -100,12 +100,13 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no else: self._store.update_nodes(root_nodes, copy=True) root_uid_map = {n._copy_source.get('uid'): n.uid for n in root_nodes} - def _copy_segments_recursive(p_uid_map: dict, p_name: str, **kwargs): + + def _copy_segments_recursive(p_uid_map: dict, p_name: str): for group_name in self._store.activated_groups(): group = self._node_groups.get(group_name) if group is None: raise ValueError(f'Node group {group_name} does not exist. 
Please check the group name ' - 'or add a new one through `create_node_group`.') + 'or add a new one through `create_node_group`.') if group['parent'] == p_name: nodes = self._store.get_nodes(doc_ids=ids, group=LAZY_ROOT_NAME, kb_id=kb_id) nodes = [ diff --git a/lazyllm/tools/rag/parsing_service/queue.py b/lazyllm/tools/rag/parsing_service/queue.py index 01c9f6a49..785cfeb45 100644 --- a/lazyllm/tools/rag/parsing_service/queue.py +++ b/lazyllm/tools/rag/parsing_service/queue.py @@ -1,4 +1,5 @@ from typing import Dict, List, Optional, Any +from datetime import datetime, timedelta from lazyllm import LOG from ...sql import SqlManager @@ -89,6 +90,77 @@ def dequeue(self, filter_by: Dict[str, Any] = None) -> Optional[Dict[str, Any]]: LOG.error(f'[SQLBasedQueue] Failed to dequeue from {self._table_name}: {e}') raise + def claim(self, worker_id: str, lease_duration: float, + status_waiting: str = 'WAITING', status_working: str = 'WORKING', + filter_by: Dict[str, Any] = None, include_task_types: List[str] = None, + exclude_task_types: List[str] = None) -> Optional[Dict[str, Any]]: + '''claim a message from the queue without removing it''' + try: + with self._sql_manager.get_session() as session: + TableCls = self._sql_manager.get_table_orm_class(self._table_name) + query = self._build_query(session, filter_by) + now = datetime.now() + if include_task_types: + query = query.filter(TableCls.task_type.in_(include_task_types)) + if exclude_task_types: + query = query.filter(~TableCls.task_type.in_(exclude_task_types)) + query = query.filter( + (TableCls.status == status_waiting) + | ((TableCls.status == status_working) + & ((TableCls.lease_expires_at < now) + | (TableCls.lease_expires_at.is_(None)))) + ) + record = query.with_for_update().first() + if not record: + return None + + record.status = status_working + record.worker_id = worker_id + record.lease_expires_at = now + timedelta(seconds=lease_duration) + record.updated_at = now + session.flush() + result = _orm_to_dict(record) + LOG.info(f'[SQLBasedQueue] Claimed from {self._table_name}') + return result + except Exception as e: + LOG.error(f'[SQLBasedQueue] Failed to claim from {self._table_name}: {e}') + raise + + def extend_lease(self, task_id: str, worker_id: str, lease_duration: float) -> bool: + try: + with self._sql_manager.get_session() as session: + TableCls = self._sql_manager.get_table_orm_class(self._table_name) + now = datetime.now() + record = session.query(TableCls).filter( + TableCls.task_id == task_id, + TableCls.worker_id == worker_id + ).first() + if not record: + return False + record.lease_expires_at = now + timedelta(seconds=lease_duration) + record.updated_at = now + session.flush() + LOG.debug(f'[SQLBasedQueue] Extended lease for {self._table_name}') + return True + except Exception as e: + LOG.error(f'[SQLBasedQueue] Failed to extend lease for {self._table_name}: {e}') + raise + + def delete(self, filter_by: Dict[str, Any]) -> int: + try: + with self._sql_manager.get_session() as session: + TableCls = self._sql_manager.get_table_orm_class(self._table_name) + query = session.query(TableCls) + if filter_by: + for key, value in filter_by.items(): + query = query.filter(getattr(TableCls, key) == value) + count = query.delete(synchronize_session=False) + LOG.info(f'[SQLBasedQueue] Deleted {count} records from {self._table_name}') + return count + except Exception as e: + LOG.error(f'[SQLBasedQueue] Failed to delete from {self._table_name}: {e}') + raise + def peek(self, filter_by: Dict[str, Any] = None) -> Optional[Dict[str, 
Any]]: try: with self._sql_manager.get_session() as session: diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index ea8a46da6..ce05f2a0a 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -4,7 +4,7 @@ import traceback import cloudpickle from datetime import datetime -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, List from lazyllm import ( LOG, ModuleBase, ServerModule, UrlModule, FastapiApp as app, @@ -14,7 +14,7 @@ from .base import ( ALGORITHM_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, FINISHED_TASK_QUEUE_TABLE_INFO, - TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, TransferParams, + TaskStatus, TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, _calculate_task_score ) from .worker import DocumentProcessorWorker as Worker @@ -32,7 +32,10 @@ class DocumentProcessor(ModuleBase): class _Impl(): def __init__(self, db_config: Optional[Dict[str, Any]] = None, num_workers: int = 1, - post_func: Optional[Callable] = None, path_prefix: Optional[str] = None): + post_func: Optional[Callable] = None, path_prefix: Optional[str] = None, + lease_duration: float = 300.0, lease_renew_interval: float = 60.0, + high_priority_task_types: Optional[List[str]] = None, + high_priority_workers: int = 1): self._db_config = db_config self._num_workers = num_workers self._post_func = post_func @@ -40,12 +43,21 @@ def __init__(self, db_config: Optional[Dict[str, Any]] = None, num_workers: int raise ValueError('Invalid post function!') self._shutdown = False self._path_prefix = path_prefix + self._lease_duration = lease_duration + self._lease_renew_interval = lease_renew_interval + self._high_priority_task_types = ( + high_priority_task_types + if high_priority_task_types is not None + else [TaskType.DOC_DELETE.value] + ) + self._high_priority_workers = max(high_priority_workers, 0) self._db_manager = None self._waiting_task_queue = None self._finished_task_queue = None self._post_func_thread = None self._workers = None + self._high_priority_workers_module = None @once_wrapper(reset_on_pickle=True) def _lazy_init(self): @@ -69,8 +81,35 @@ def _lazy_init(self): self._post_func_thread.start() if self._num_workers > 0: - self._workers = Worker(db_config=self._db_config, num_workers=self._num_workers) - self._workers.start() + high_priority_types = [t for t in (self._high_priority_task_types or []) if t] + high_priority_workers = 0 + if high_priority_types and self._high_priority_workers > 0: + if self._num_workers <= 1: + LOG.warning('[DocumentProcessor] num_workers <= 1, high priority workers disabled') + else: + high_priority_workers = min(self._high_priority_workers, self._num_workers - 1) + if high_priority_workers < self._high_priority_workers: + LOG.warning('[DocumentProcessor] high_priority_workers trimmed to fit num_workers') + normal_workers = self._num_workers - high_priority_workers + if high_priority_workers > 0: + self._high_priority_workers_module = Worker( + db_config=self._db_config, + num_workers=high_priority_workers, + lease_duration=self._lease_duration, + lease_renew_interval=self._lease_renew_interval, + high_priority_task_types=high_priority_types, + high_priority_only=True, + ) + self._high_priority_workers_module.start() + if normal_workers > 0: + self._workers = Worker( + db_config=self._db_config, + num_workers=normal_workers, + lease_duration=self._lease_duration, + 
lease_renew_interval=self._lease_renew_interval, + high_priority_task_types=high_priority_types, + ) + self._workers.start() LOG.info('[DocumentProcessor] Lazy initialization completed!') def __getstate__(self): @@ -80,6 +119,7 @@ def __getstate__(self): state['_finished_task_queue'] = None state['_post_func_thread'] = None state['_workers'] = None + state['_high_priority_workers_module'] = None return state def __setstate__(self, state): @@ -257,7 +297,7 @@ def get_algo_group_info(self, algo_id: str) -> None: raise fastapi.HTTPException(status_code=500, detail=f'Failed to get group info: {str(e)}') @app.post('/doc/add') - def add_doc(self, request: AddDocRequest): + def add_doc(self, request: AddDocRequest): # noqa: C901 self._lazy_init() if self._shutdown: raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...') @@ -337,13 +377,18 @@ def add_doc(self, request: AddDocRequest): try: user_priority = request.priority if request.priority is not None else 0 task_score = _calculate_task_score(task_type, user_priority) + now = datetime.now() self._waiting_task_queue.enqueue( task_id=task_id, task_type=task_type, user_priority=user_priority, task_score=task_score, message=payload_json, - created_at=datetime.now(), + status=TaskStatus.WAITING.value, + worker_id=None, + lease_expires_at=None, + created_at=now, + updated_at=now, ) LOG.info(f'[DocumentProcessor] Task {task_id} (type={task_type}, user_priority={user_priority}, ' f'score={task_score}) submitted to database queue successfully') @@ -379,13 +424,18 @@ def update_meta(self, request: UpdateMetaRequest): task_type = TaskType.DOC_UPDATE_META.value user_priority = request.priority if request.priority is not None else 0 task_score = _calculate_task_score(task_type, user_priority) + now = datetime.now() self._waiting_task_queue.enqueue( task_id=task_id, task_type=task_type, user_priority=user_priority, task_score=task_score, message=payload_json, - created_at=datetime.now(), + status=TaskStatus.WAITING.value, + worker_id=None, + lease_expires_at=None, + created_at=now, + updated_at=now, ) LOG.info(f'[DocumentProcessor] Update meta task {task_id} (user_priority={user_priority}, ' f'score={task_score}) submitted to database queue successfully') @@ -420,13 +470,18 @@ def delete_doc(self, request: DeleteDocRequest): task_type = TaskType.DOC_DELETE.value user_priority = request.priority if request.priority is not None else 0 task_score = _calculate_task_score(task_type, user_priority) + now = datetime.now() self._waiting_task_queue.enqueue( task_id=task_id, task_type=task_type, user_priority=user_priority, task_score=task_score, message=payload_json, - created_at=datetime.now(), + status=TaskStatus.WAITING.value, + worker_id=None, + lease_expires_at=None, + created_at=now, + updated_at=now, ) LOG.info(f'[DocumentProcessor] Delete task {task_id} (user_priority={user_priority}, ' f'score={task_score}) submitted to database queue successfully') @@ -450,9 +505,11 @@ def cancel(self, request: CancelTaskRequest): try: # NOTE: only the task in waiting state can be canceled cancel_status = False - waiting_task = self._waiting_task_queue.dequeue(filter_by={'task_id': task_id}) + deleted = self._waiting_task_queue.delete( + filter_by={'task_id': task_id, 'status': TaskStatus.WAITING.value} + ) message = '' - if waiting_task: + if deleted: cancel_status = True message = 'Canceled by user' else: @@ -510,14 +567,24 @@ def __call__(self, func_name: str, *args, **kwargs): def __init__(self, port: int = None, url: str = None, num_workers: 
int = 1, db_config: Optional[Dict[str, Any]] = None, launcher: Optional[Launcher] = None, post_func: Optional[Callable] = None, - path_prefix: Optional[str] = None): + path_prefix: Optional[str] = None, lease_duration: float = 300.0, + lease_renew_interval: float = 60.0, high_priority_task_types: Optional[List[str]] = None, + high_priority_workers: int = 1): super().__init__() self._raw_impl = None # save the reference of the original Impl object self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') if not url: # create the Impl object (lazy loading, no threads created) - self._raw_impl = DocumentProcessor._Impl(num_workers=num_workers, db_config=self._db_config, - post_func=post_func, path_prefix=path_prefix) + self._raw_impl = DocumentProcessor._Impl( + num_workers=num_workers, + db_config=self._db_config, + post_func=post_func, + path_prefix=path_prefix, + lease_duration=lease_duration, + lease_renew_interval=lease_renew_interval, + high_priority_task_types=high_priority_task_types, + high_priority_workers=high_priority_workers, + ) self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) else: self._impl = UrlModule(url=ensure_call_endpoint(url)) diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index 43f55f4b7..e2d887101 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -1,15 +1,19 @@ import json +import os +import subprocess import time import traceback import threading import cloudpickle from datetime import datetime +from uuid import uuid4 from lazyllm import LOG, FastapiApp as app, ModuleBase, ServerModule, once_wrapper from ..utils import BaseResponse, _get_default_db_config from .base import ( FINISHED_TASK_QUEUE_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, - TaskStatus, TaskType, ALGORITHM_TABLE_INFO + TaskStatus, TaskType, ALGORITHM_TABLE_INFO, AddDocRequest, UpdateMetaRequest, + DeleteDocRequest, _calculate_task_score ) from .impl import _Processor from .queue import _SQLBasedQueue as Queue @@ -21,13 +25,28 @@ class DocumentProcessorWorker(ModuleBase): class _Impl(): - def __init__(self, db_config: dict = None): + def __init__(self, db_config: dict = None, task_poller=None, lease_duration: float = 300.0, + lease_renew_interval: float = 60.0, high_priority_task_types: list[str] = None, + high_priority_only: bool = False): self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') self._shutdown = False self._processors: dict[str, _Processor] = {} # algo_id -> _Processor self._waiting_task_queue = None self._finished_task_queue = None self._worker_thread = None + self._poller_thread = None + if task_poller is not None and not callable(task_poller): + raise TypeError('task_poller is not callable') + self._task_poller = task_poller + self._task_poller_impl = self._wrap_task_poller(task_poller) if task_poller else None + self._worker_id = f'{self._get_worker_identity()}-{uuid4()}' + self._in_progress_task = None + self._lease_thread = None + self._lease_stop_event = None + self._lease_duration = lease_duration + self._lease_renew_interval = lease_renew_interval + self._high_priority_task_types = set(high_priority_task_types or []) + self._high_priority_only = high_priority_only @once_wrapper(reset_on_pickle=True) def _lazy_init(self): @@ -48,7 +67,77 @@ def _lazy_init(self): tables_info_dict={'tables': [ALGORITHM_TABLE_INFO]}, ) - LOG.info('[DocumentProcessorWorker._Impl] initialized') + 
LOG.info(f'{self._log_prefix()} initialized') + + def _get_worker_identity(self) -> str: + env_keys = ('POD_IP', 'POD_NAME', 'HOSTNAME') + for key in env_keys: + value = os.getenv(key) + if value: + return value + try: + ip = subprocess.check_output(['hostname', '-i'], text=True).strip() + if ip: + return ip + except Exception: + pass + return 'worker' + + def _log_prefix(self, task_id: str = None) -> str: + if task_id: + return f'[DocumentProcessorWorker._Impl][worker_id={self._worker_id}][task_id={task_id}]' + return f'[DocumentProcessorWorker._Impl][worker_id={self._worker_id}]' + + def _wrap_task_poller(self, task_poller): + def _impl(): + result = task_poller() + if result is None: + return [] + return result if isinstance(result, list) else [result] + return _impl + + def _start_lease_renewal(self, task_id: str): + if self._lease_renew_interval <= 0: + return + self._lease_stop_event = threading.Event() + + def _renew(): + while not self._lease_stop_event.wait(self._lease_renew_interval): + try: + self._waiting_task_queue.extend_lease(task_id, self._worker_id, self._lease_duration) + except Exception as e: + LOG.warning(f'{self._log_prefix(task_id)} Failed to extend lease: {e}') + + self._lease_thread = threading.Thread(target=_renew, daemon=True) + self._lease_thread.start() + + def _stop_lease_renewal(self): + if self._lease_stop_event is not None: + self._lease_stop_event.set() + if self._lease_thread is not None and self._lease_thread.is_alive(): + self._lease_thread.join(timeout=2.0) + self._lease_thread = None + self._lease_stop_event = None + + def _fail_in_progress_task(self, reason: str): + if not self._in_progress_task: + return + task_id = self._in_progress_task.get('task_id') + task_type = self._in_progress_task.get('task_type') + if task_id and task_type: + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.FAILED, + error_code='PRESTOP', + error_msg=reason, + ) + deleted = self._waiting_task_queue.delete( + filter_by={'task_id': task_id, 'worker_id': self._worker_id} + ) + if deleted == 0: + LOG.warning(f'{self._log_prefix(task_id)} Failed to delete in-progress task') + self._in_progress_task = None @app.get('/health') def get_health(self): @@ -57,7 +146,7 @@ def get_health(self): return BaseResponse(code=503, msg='Worker thread not started') if not self._worker_thread.is_alive(): - LOG.error('[DocumentProcessorWorker._Impl] Worker thread is dead') + LOG.error(f'{self._log_prefix()} Worker thread is dead') return BaseResponse(code=503, msg='Worker thread is not alive') return BaseResponse(code=200, msg='success') @@ -68,16 +157,23 @@ def get_prestop(self): if self._worker_thread is not None and self._worker_thread.is_alive(): self._worker_thread.join(timeout=5.0) if self._worker_thread.is_alive(): - LOG.warning('[DocumentProcessorWorker._Impl] Worker thread did not stop within timeout') + LOG.warning(f'{self._log_prefix()} Worker thread did not stop within timeout') + self._fail_in_progress_task('prestop timeout') else: - LOG.info('[DocumentProcessorWorker._Impl] Worker thread stopped') + LOG.info(f'{self._log_prefix()} Worker thread stopped') + if self._poller_thread is not None and self._poller_thread.is_alive(): + self._poller_thread.join(timeout=5.0) + if self._poller_thread.is_alive(): + LOG.warning(f'{self._log_prefix()} Poller thread did not stop within timeout') + else: + LOG.info(f'{self._log_prefix()} Poller thread stopped') return BaseResponse(code=200, msg='success') def _get_or_create_processor(self, algo_id: 
str) -> _Processor: try: self._lazy_init() if algo_id in self._processors: - LOG.debug(f'[DocumentProcessorWorker._Impl] Using cached processor for {algo_id}') + LOG.debug(f'{self._log_prefix()} Using cached processor for {algo_id}') return self._processors[algo_id] with self._db_manager.get_session() as session: @@ -96,10 +192,10 @@ def _get_or_create_processor(self, algo_id: str) -> _Processor: processor = _Processor(algo_id, store, reader, node_groups, schema_extractor, display_name, description) self._processors[algo_id] = processor - LOG.info(f'[DocumentProcessorWorker._Impl] Created processor for {algo_id}') + LOG.info(f'{self._log_prefix()} Created processor for {algo_id}') return self._processors[algo_id] except Exception as e: - LOG.warning(f'[DocumentProcessorWorker._Impl] Failed to load algo: {e}') + LOG.warning(f'{self._log_prefix()} Failed to load algo: {e}') raise e def _exec_add_task(self, processor: _Processor, task_id: str, payload: dict): @@ -117,7 +213,7 @@ def _exec_add_task(self, processor: _Processor, task_id: str, payload: dict): processor.add_doc(input_files=input_files, ids=ids, metadatas=metadatas, kb_id=kb_id) except Exception as e: - LOG.error(f'[DocumentProcessorWorker._Impl] Task-{task_id}: execute add task failed, error: {e}') + LOG.error(f'{self._log_prefix(task_id)} Execute add task failed: {e}') raise e def _exec_reparse_task( @@ -147,9 +243,9 @@ def _exec_reparse_task( doc_paths=reparse_files, metadatas=reparse_metadatas, kb_id=kb_id) except Exception as e: - LOG.error(f'[DocumentProcessorWorker._Impl] Task-{task_id}: execute reparse task failed, error: {e}') + LOG.error(f'{self._log_prefix(task_id)} Execute reparse task failed: {e}') raise e - + def _exec_transfer_task(self, processor: _Processor, task_id: str, payload: dict): try: file_infos = payload.get('file_infos') @@ -175,7 +271,7 @@ def _exec_transfer_task(self, processor: _Processor, task_id: str, payload: dict transfer_mode=transfer_mode, target_kb_id=target_kb_id, target_doc_ids=target_doc_ids) except Exception as e: - LOG.error(f'[DocumentProcessorWorker._Impl] Task-{task_id}: execute transfer task failed, error: {e}') + LOG.error(f'{self._log_prefix(task_id)} Execute transfer task failed: {e}') raise e def _exec_delete_task(self, processor: _Processor, task_id: str, payload: dict): @@ -184,7 +280,7 @@ def _exec_delete_task(self, processor: _Processor, task_id: str, payload: dict): doc_ids = payload.get('doc_ids') processor.delete_doc(doc_ids=doc_ids, kb_id=kb_id) except Exception as e: - LOG.error(f'[DocumentProcessorWorker._Impl] Task-{task_id}: execute delete task failed, error: {e}') + LOG.error(f'{self._log_prefix(task_id)} Execute delete task failed: {e}') raise e def _exec_update_meta_task(self, processor: _Processor, task_id: str, payload: dict): @@ -196,10 +292,128 @@ def _exec_update_meta_task(self, processor: _Processor, task_id: str, payload: d metadata = file_info.get('metadata') processor.update_doc_meta(doc_id=doc_id, metadata=metadata, kb_id=kb_id) except Exception as e: - LOG.error(f'[DocumentProcessorWorker._Impl] Task-{task_id}: execute update meta task failed,' - f'error: {e}') + LOG.error(f'{self._log_prefix(task_id)} Execute update meta task failed: {e}') raise e + def _resolve_task_type(self, request: AddDocRequest) -> str: # noqa C901 + new_file_ids = [] + reparse_file_ids = [] + transfer_mode = None + target_algo_id = None + target_kb_id = None + + for file_info in request.file_infos: + if file_info.reparse_group is not None: + reparse_file_ids.append(file_info.doc_id) 
+ else: + new_file_ids.append(file_info.doc_id) + if file_info.transfer_params: + if target_algo_id is not None and target_algo_id != file_info.transfer_params.target_algo_id: + raise ValueError('transfer_params.target_algo_id must be the same for all files') + if target_kb_id is not None and target_kb_id != file_info.transfer_params.target_kb_id: + raise ValueError('transfer_params.target_kb_id must be the same for all files') + if transfer_mode is not None and transfer_mode != file_info.transfer_params.mode: + raise ValueError('transfer_params.mode must be the same for all files') + if file_info.transfer_params.target_algo_id != request.algo_id: + raise ValueError('transfer_params.target_algo_id must be the same for all files') + target_algo_id = file_info.transfer_params.target_algo_id + target_kb_id = file_info.transfer_params.target_kb_id + transfer_mode = file_info.transfer_params.mode + if transfer_mode not in ['cp', 'mv']: + raise ValueError('transfer_params.mode must be one of [cp, mv]') + + if new_file_ids and reparse_file_ids: + raise ValueError('new_file_ids and reparse_file_ids cannot be specified at the same time') + if transfer_mode: + return TaskType.DOC_TRANSFER.value + if new_file_ids: + return TaskType.DOC_ADD.value + if reparse_file_ids: + return TaskType.DOC_REPARSE.value + raise ValueError('no input files or reparse group specified') + + def _validate_task_payload(self, task_type: str, payload: dict): + if not isinstance(payload, dict): + raise ValueError('payload must be a dict') + if task_type in ( + TaskType.DOC_ADD.value, + TaskType.DOC_REPARSE.value, + TaskType.DOC_TRANSFER.value, + TaskType.DOC_UPDATE_META.value, + ): + file_infos = payload.get('file_infos') + if not isinstance(file_infos, list) or not file_infos: + raise ValueError(f'file_infos is required for task_type {task_type}') + if task_type == TaskType.DOC_DELETE.value: + doc_ids = payload.get('doc_ids') + if not isinstance(doc_ids, list) or not doc_ids: + raise ValueError('doc_ids is required for task_type DOC_DELETE') + + def _enqueue_task_from_payload(self, task: dict): + try: + task_type = task.get('task_type') + if task_type == TaskType.DOC_DELETE.value: + task_info = DeleteDocRequest(**task) + elif task_type == TaskType.DOC_UPDATE_META.value: + task_info = UpdateMetaRequest(**task) + else: + task_info = AddDocRequest(**task) + task_type = task_type or self._resolve_task_type(task_info) + task_id = task_info.task_id + payload = task_info.model_dump() + self._validate_task_payload(task_type, payload) + user_priority = task_info.priority if task_info.priority is not None else 0 + task_score = _calculate_task_score(task_type, user_priority) + payload_json = json.dumps(payload, ensure_ascii=False) + now = datetime.now() + + self._waiting_task_queue.enqueue( + task_id=task_id, + task_type=task_type, + user_priority=user_priority, + task_score=task_score, + message=payload_json, + status=TaskStatus.WAITING.value, + worker_id=None, + lease_expires_at=None, + created_at=now, + updated_at=now, + ) + LOG.info(f'{self._log_prefix(task_id)} [Poller] task (type={task_type}, ' + f'user_priority={user_priority}, score={task_score}) ' + 'submitted to database queue successfully') + except Exception as e: + LOG.warning(f'{self._log_prefix()} [Poller] Skip invalid task payload: {e}. 
' + f'payload={task}') + + def _poller(self): # noqa: C901 + while not self._shutdown: + try: + tasks = self._task_poller_impl() + if not tasks: + time.sleep(0.1) + continue + for task in tasks: + self._enqueue_task_from_payload(task) + except Exception as e: + LOG.error(f'{self._log_prefix()} [Poller] fetch failed: {e}') + time.sleep(WORKER_ERROR_RETRY_INTERVAL) + LOG.info(f'{self._log_prefix()} [Poller] stopped') + + def _poll_task(self): + include_types = None + exclude_types = None + if self._high_priority_task_types and self._high_priority_only: + include_types = list(self._high_priority_task_types) + return self._waiting_task_queue.claim( + worker_id=self._worker_id, + lease_duration=self._lease_duration, + status_waiting=TaskStatus.WAITING.value, + status_working=TaskStatus.WORKING.value, + include_task_types=include_types, + exclude_task_types=exclude_types, + ) + def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: TaskStatus, error_code: str = None, error_msg: str = None): try: @@ -213,33 +427,33 @@ def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: Task error_msg=error_msg if error_msg else 'success' ) if task_status == TaskStatus.FINISHED: - LOG.info(f'[DocumentProcessorWorker._Impl] Task {task_id} finished successfully') + LOG.info(f'{self._log_prefix(task_id)} Task finished successfully') else: - LOG.error(f'[DocumentProcessorWorker._Impl] Task {task_id} completed with status {task_status}:' - f' {error_msg}') + LOG.error(f'{self._log_prefix(task_id)} Task completed with status {task_status}: {error_msg}') except Exception as e: - LOG.error(f'[DocumentProcessorWorker._Impl] Failed to enqueue finished task {task_id}: {e}') + LOG.error(f'{self._log_prefix(task_id)} Failed to enqueue finished task: {e}') - def _worker_impl(self): + def _worker_impl(self): # noqa: C901 while not self._shutdown: task_id = None task_type = None try: - task_data = self._waiting_task_queue.dequeue() + task_data = self._poll_task() if not task_data: time.sleep(0.1) continue task_id = task_data['task_id'] task_type = task_data['task_type'] + self._in_progress_task = {'task_id': task_id, 'task_type': task_type} + self._start_lease_renewal(task_id) payload = json.loads(task_data.get('message')) algo_id = payload.get('algo_id') if not algo_id: - raise ValueError(f'[DocumentProcessorWorker._Impl] task_id {task_id} is missing algo_id in ' - f'payload: {payload}') + raise ValueError(f'{self._log_prefix(task_id)} task_id is missing algo_id in payload: {payload}') - LOG.info(f'[DocumentProcessorWorker._Impl] Start processing task {task_id}, type: {task_type},' - f' algo_id: {algo_id}') + LOG.info(f'{self._log_prefix(task_id)} Start processing task, type: {task_type}, ' + f'algo_id: {algo_id}') processor = self._get_or_create_processor(algo_id) if task_type == TaskType.DOC_ADD.value: @@ -253,46 +467,78 @@ def _worker_impl(self): elif task_type == TaskType.DOC_TRANSFER.value: self._exec_transfer_task(processor, task_id, payload) else: - raise ValueError(f'[DocumentProcessorWorker._Impl] Unknown task type: {task_type}') + raise ValueError(f'{self._log_prefix(task_id)} Unknown task type: {task_type}') self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FINISHED, error_code='200', error_msg='success') + deleted = self._waiting_task_queue.delete( + filter_by={'task_id': task_id, 'worker_id': self._worker_id} + ) + if deleted == 0: + LOG.warning(f'{self._log_prefix(task_id)} Failed to delete finished task') except Exception as e: - 
LOG.error(f'[DocumentProcessorWorker._Impl] Failed to run task {task_id}: {e},' - f' {traceback.format_exc()}') + LOG.error(f'{self._log_prefix(task_id)} Failed to run task: {e}, {traceback.format_exc()}') if task_id and task_type: self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FAILED, error_code=type(e).__name__, error_msg=str(e)) + deleted = self._waiting_task_queue.delete( + filter_by={'task_id': task_id, 'worker_id': self._worker_id} + ) + if deleted == 0: + LOG.warning(f'{self._log_prefix(task_id)} Failed to delete failed task') time.sleep(WORKER_ERROR_RETRY_INTERVAL) continue + finally: + self._stop_lease_renewal() + self._in_progress_task = None def start(self): - LOG.info('[DocumentProcessorWorker._Impl] Starting worker...') + LOG.info(f'{self._log_prefix()} Starting worker...') self._lazy_init() if self._worker_thread is not None and self._worker_thread.is_alive(): - LOG.warning('[DocumentProcessorWorker._Impl] Worker thread is already running') + LOG.warning(f'{self._log_prefix()} Worker thread is already running') return self._shutdown = False + if self._task_poller_impl is not None: + if self._poller_thread is None or not self._poller_thread.is_alive(): + self._poller_thread = threading.Thread(target=self._poller, daemon=True) + self._poller_thread.start() self._worker_thread = threading.Thread(target=self._worker_impl, daemon=True) self._worker_thread.start() - LOG.info('[DocumentProcessorWorker._Impl] Worker thread started') + LOG.info(f'{self._log_prefix()} Worker thread started') def shutdown(self): - LOG.info('[DocumentProcessorWorker._Impl] Shutting down worker...') + LOG.info(f'{self._log_prefix()} Shutting down worker...') self._shutdown = True if self._worker_thread is not None and self._worker_thread.is_alive(): self._worker_thread.join(timeout=5.0) if self._worker_thread.is_alive(): - LOG.warning('[DocumentProcessorWorker._Impl] Worker thread did not stop within timeout') + LOG.warning(f'{self._log_prefix()} Worker thread did not stop within timeout') + self._fail_in_progress_task('shutdown timeout') + else: + LOG.info(f'{self._log_prefix()} Worker thread stopped') + if self._poller_thread is not None and self._poller_thread.is_alive(): + self._poller_thread.join(timeout=5.0) + if self._poller_thread.is_alive(): + LOG.warning(f'{self._log_prefix()} Poller thread did not stop within timeout') else: - LOG.info('[DocumentProcessorWorker._Impl] Worker thread stopped') + LOG.info(f'{self._log_prefix()} Poller thread stopped') - def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = None): + def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = None, + task_poller=None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, + high_priority_task_types: list[str] = None, high_priority_only: bool = False): super().__init__() self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') self._num_workers = num_workers self._port = port - worker_impl = DocumentProcessorWorker._Impl(db_config=self._db_config) + worker_impl = DocumentProcessorWorker._Impl( + db_config=self._db_config, + task_poller=task_poller, + lease_duration=lease_duration, + lease_renew_interval=lease_renew_interval, + high_priority_task_types=high_priority_task_types, + high_priority_only=high_priority_only, + ) self._worker_impl = ServerModule(worker_impl, port=self._port, num_replicas=self._num_workers) LOG.info(f'[DocumentProcessorWorker] Worker initialized with {num_workers} 
workers') diff --git a/lazyllm/tools/rag/store/utils.py b/lazyllm/tools/rag/store/utils.py index c2666ddc3..2986b1276 100644 --- a/lazyllm/tools/rag/store/utils.py +++ b/lazyllm/tools/rag/store/utils.py @@ -190,36 +190,36 @@ def presign_obj_from_s3( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, endpoint_url: Optional[str] = None, - region_name: str = "us-east-1", - client_method: str = "get_object", + region_name: str = 'us-east-1', + client_method: str = 'get_object', expires_in: int = 3600, extra_params: Optional[Dict[str, Any]] = None, ) -> str: - spec = importlib.util.find_spec("botocore.client") + spec = importlib.util.find_spec('botocore.client') if spec is None: raise ImportError( - "Please install boto3 to use botocore module. " - "You can install it with `pip install boto3`" + 'Please install boto3 to use botocore module. ' + 'You can install it with `pip install boto3`' ) m = importlib.util.module_from_spec(spec) spec.loader.exec_module(m) Config = m.Config s3_client = boto3.client( - "s3", + 's3', region_name=region_name, endpoint_url=endpoint_url, aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, config=Config( - signature_version="s3v4", + signature_version='s3v4', ), ) params = { - "Bucket": bucket_name, - "Key": object_key, + 'Bucket': bucket_name, + 'Key': object_key, } if extra_params: @@ -233,7 +233,7 @@ def presign_obj_from_s3( ) return url except Exception as e: - LOG.error(f"Generate presigned url failed: {e}") + LOG.error(f'Generate presigned url failed: {e}') raise e def fibonacci_backoff(max_retries: int = INSERT_MAX_RETRIES): diff --git a/lazyllm/tools/sql/sql_manager.py b/lazyllm/tools/sql/sql_manager.py index da4168054..d0c7f53e1 100644 --- a/lazyllm/tools/sql/sql_manager.py +++ b/lazyllm/tools/sql/sql_manager.py @@ -106,8 +106,7 @@ def _create_tables_by_info(self, tables_info: TablesInfo): column_name = column_info.name is_primary = column_info.is_primary_key default_value = column_info.default - # Use text for unsupported column type - real_type = self.PYTYPE_TO_SQL_MAP.get(column_type, sqlalchemy.Text) + real_type = self._sql_type_for(column_type) # Handle default value if default_value is not None: attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, From 8c76b157d4771fb3e5820209a66504a5f028fa81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 21 Jan 2026 10:55:23 +0800 Subject: [PATCH 03/46] add doc and modify bt review --- lazyllm/docs/tools.py | 44 +++++++++++++++ lazyllm/tools/rag/doc_node.py | 4 +- lazyllm/tools/rag/parsing_service/base.py | 38 +++++++++++++ lazyllm/tools/rag/parsing_service/impl.py | 58 ++++++++++--------- lazyllm/tools/rag/parsing_service/server.py | 62 ++------------------- lazyllm/tools/rag/parsing_service/worker.py | 40 +------------ 6 files changed, 124 insertions(+), 122 deletions(-) diff --git a/lazyllm/docs/tools.py b/lazyllm/docs/tools.py index 5c08357f9..13213fa71 100644 --- a/lazyllm/docs/tools.py +++ b/lazyllm/docs/tools.py @@ -3678,6 +3678,32 @@ def my_reranker(node: DocNode, **kwargs): None """) +add_chinese_doc('rag.doc_node.DocNode.copy', """ +复制当前 DocNode,并生成新的 uid。 + +复制后的节点会记录来源信息(_copy_source),可选更新 metadata/global_metadata。 + +Args: + global_metadata (dict): 需要合并到 global_metadata 的字段 + metadata (dict): 需要合并到 metadata 的字段 + +Returns: + DocNode: 复制后的节点 +""") + +add_english_doc('rag.doc_node.DocNode.copy', """ +Copy the current DocNode and generate a new uid. 
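+
+For illustration, a typical call looks like this (the node variable and the literal values are placeholders):
+
+    dup = node.copy(metadata={'page': 1}, global_metadata={'owner': 'alice'})
+    assert dup.uid != node.uid                     # a fresh uid is generated
+    assert dup._copy_source['uid'] == node.uid     # the source uid is kept for tracing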
+ +The copied node records its source (_copy_source) and can optionally merge metadata/global_metadata. + +Args: + global_metadata (dict): Fields to merge into global_metadata + metadata (dict): Fields to merge into metadata + +Returns: + DocNode: The copied node +""") + add_chinese_doc('rag.parsing_service.server.DocumentProcessor', """ 文档处理服务类,启动后可对外提供文档处理服务,支持文档的添加、删除和更新等操作。 服务内部采取生产者-消费者模式,通过队列管理文档处理任务,支持异步处理文档任务,支持任务状态回调通知。 @@ -3692,6 +3718,10 @@ def my_reranker(node: DocNode, **kwargs): def post_func(task_id: str, task_status: str = None, error_code: str = None, error_msg: str = None): pass path_prefix (Optional[str]): 用于配置上传文件存储路径前缀,默认为None。 + lease_duration (float): 任务租约时长(秒),默认为300。 + lease_renew_interval (float): 租约续租间隔(秒),默认为60。 + high_priority_task_types (Optional[List[str]]): 高优任务类型列表,默认包含 DOC_DELETE。 + high_priority_workers (int): 高优任务 worker 数量,默认1。 """) add_english_doc('rag.parsing_service.server.DocumentProcessor', """ @@ -3708,6 +3738,10 @@ def post_func(task_id: str, task_status: str = None, error_code: str = None, err def post_func(task_id: str, task_status: str = None, error_code: str = None, error_msg: str = None): pass path_prefix (Optional[str]): Used to configure the prefix of the uploaded file storage path, defaults to None. + lease_duration (float): Task lease duration in seconds, defaults to 300. + lease_renew_interval (float): Lease renewal interval in seconds, defaults to 60. + high_priority_task_types (Optional[List[str]]): High priority task types, defaults to [DOC_DELETE]. + high_priority_workers (int): Number of high priority workers, defaults to 1. """) add_example('rag.parsing_service.server.DocumentProcessor', """ @@ -3798,6 +3832,11 @@ def post_func(task_id: str, task_status: str = None, error_code: str = None, err db_config (Optional[Dict[str, Any]]): 用于配置SqlManager实现数据库连接,默认为None,当为None时,使用默认数据库配置。 num_workers (int): 工作线程数,默认为1, 当大于1时,内部基于ray集群启动多个工作线程,否则仅启动一个工作线程。 port (Optional[int]): 服务端口号。默认为None,当为None时,将自动分配端口。 + task_poller (Optional[Callable]): 外部任务拉取函数,可选。 + lease_duration (float): 任务租约时长(秒),默认为300。 + lease_renew_interval (float): 租约续租间隔(秒),默认为60。 + high_priority_task_types (Optional[List[str]]): 高优任务类型列表,可选。 + high_priority_only (bool): 仅处理高优任务,默认为False。 ''') add_english_doc('rag.parsing_service.worker.DocumentProcessorWorker', ''' @@ -3808,6 +3847,11 @@ def post_func(task_id: str, task_status: str = None, error_code: str = None, err db_config (Optional[Dict[str, Any]]): Used to configure the database connection information for SqlManager, defaults to None, when it is None, the default database configuration is used. num_workers (int): Number of worker threads, defaults to 1, when it is greater than 1, multiple worker threads are started internally based on the ray cluster, otherwise only one worker thread is started. port (Optional[int]): Service port number. Defaults to None, when it is None, a random port will be assigned. + task_poller (Optional[Callable]): External task poller callback, optional. + lease_duration (float): Task lease duration in seconds, defaults to 300. + lease_renew_interval (float): Lease renewal interval in seconds, defaults to 60. + high_priority_task_types (Optional[List[str]]): High priority task types, optional. + high_priority_only (bool): Process high priority tasks only, defaults to False. 
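+
+A minimal start-up sketch (illustrative only; my_poller below is a stand-in for a real task source and the
+argument values are placeholders, not recommended settings):
+
+    def my_poller():
+        # return a task dict (or a list of dicts) shaped like AddDocRequest /
+        # DeleteDocRequest / UpdateMetaRequest, or None when there is nothing to do
+        return None
+
+    worker = DocumentProcessorWorker(num_workers=1, task_poller=my_poller,
+                                     lease_duration=300.0, lease_renew_interval=60.0,
+                                     high_priority_task_types=['DOC_DELETE'])
+    worker.start()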
''') add_chinese_doc('rag.parsing_service.worker.DocumentProcessorWorker.start', ''' diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 3dc0a1f95..0cdd420a9 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -287,9 +287,9 @@ def copy(self, global_metadata: dict = None, metadata: dict = None) -> 'DocNode' RAG_DOC_ID: self.global_metadata.get(RAG_DOC_ID)} node._uid = str(uuid.uuid4()) if metadata: - node.metadata = node.metadata.update(metadata) + node.metadata.update(metadata) if global_metadata: - node.global_metadata = node.global_metadata.update(global_metadata) + node.global_metadata.update(global_metadata) return node def with_score(self, score): diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index d7274d05e..486b67c3f 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -103,6 +103,44 @@ def _calculate_task_score(task_type: str, user_priority: int) -> int: return type_weight * 10 - user_priority * 15 +def _resolve_add_doc_task_type(request: AddDocRequest) -> str: # noqa: C901 + new_file_ids = [] + reparse_file_ids = [] + transfer_mode = None + target_algo_id = None + target_kb_id = None + + for file_info in request.file_infos: + if file_info.reparse_group is not None: + reparse_file_ids.append(file_info.doc_id) + else: + new_file_ids.append(file_info.doc_id) + if file_info.transfer_params: + if target_algo_id is not None and target_algo_id != file_info.transfer_params.target_algo_id: + raise ValueError('transfer_params.target_algo_id must be the same for all files') + if target_kb_id is not None and target_kb_id != file_info.transfer_params.target_kb_id: + raise ValueError('transfer_params.target_kb_id must be the same for all files') + if transfer_mode is not None and transfer_mode != file_info.transfer_params.mode: + raise ValueError('transfer_params.mode must be the same for all files') + if file_info.transfer_params.target_algo_id != request.algo_id: + raise ValueError('transfer_params.target_algo_id must be the same for all files') + target_algo_id = file_info.transfer_params.target_algo_id + target_kb_id = file_info.transfer_params.target_kb_id + transfer_mode = file_info.transfer_params.mode + if transfer_mode not in ['cp', 'mv']: + raise ValueError('transfer_params.mode must be one of [cp, mv]') + + if new_file_ids and reparse_file_ids: + raise ValueError('new_file_ids and reparse_file_ids cannot be specified at the same time') + if transfer_mode: + return TaskType.DOC_TRANSFER.value + if new_file_ids: + return TaskType.DOC_ADD.value + if reparse_file_ids: + return TaskType.DOC_REPARSE.value + raise ValueError('no input files or reparse group specified') + + # Waiting task queue table WAITING_TASK_QUEUE_TABLE_INFO = { 'name': 'lazyllm_waiting_task_queue', diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index 396c46cac..a49ffcab1 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -100,32 +100,9 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no else: self._store.update_nodes(root_nodes, copy=True) root_uid_map = {n._copy_source.get('uid'): n.uid for n in root_nodes} - - def _copy_segments_recursive(p_uid_map: dict, p_name: str): - for group_name in self._store.activated_groups(): - group = self._node_groups.get(group_name) - if group is None: - raise ValueError(f'Node group 
{group_name} does not exist. Please check the group name ' - 'or add a new one through `create_node_group`.') - if group['parent'] == p_name: - nodes = self._store.get_nodes(doc_ids=ids, group=LAZY_ROOT_NAME, kb_id=kb_id) - nodes = [ - n.copy( - global_metadata={ - RAG_KB_ID: target_kb_id, RAG_DOC_ID: doc_id_map[ - n.global_metadata[RAG_DOC_ID] - ][0] - }, - metadata=doc_id_map[n.global_metadata[RAG_DOC_ID]][1] - ) for n in nodes - ] - uid_map = {} - for n in nodes: - uid_map[n._copy_source.get('uid')] = n.uid - n.parent = p_uid_map.get(n.parent, None) if n.parent else None - self._store.update_nodes(nodes, copy=True) - if nodes: _copy_segments_recursive(uid_map, group_name) - _copy_segments_recursive(p_uid_map=root_uid_map, p_name=LAZY_ROOT_NAME) + self._copy_segments_recursive(ids=ids, kb_id=kb_id, target_kb_id=target_kb_id, + doc_id_map=doc_id_map, p_uid_map=root_uid_map, + p_name=LAZY_ROOT_NAME) for future in schema_futures: try: @@ -137,7 +114,7 @@ def _copy_segments_recursive(p_uid_map: dict, p_name: str): raise schema_errors[0] add_time = time.time() - add_start LOG.info(f'[_Processor - add_doc] Add documents done! files:{input_files}, ' - f'Total Time: {load_time}s, Load Time: {add_time}s') + f'Total Time: {add_time}s, Data Loading Time: {load_time}s') except Exception as e: LOG.error(f'Add documents failed: {e}, {traceback.format_exc()}') raise e @@ -170,6 +147,33 @@ def _create_nodes_recursive(self, p_nodes: List[DocNode], p_name: str): nodes = self._create_nodes_impl(p_nodes, group_name) if nodes: self._create_nodes_recursive(nodes, group_name) + def _copy_segments_recursive(self, ids: List[str], kb_id: str, target_kb_id: str, doc_id_map: Dict[str, tuple], + p_uid_map: dict, p_name: str): + for group_name in self._store.activated_groups(): + group = self._node_groups.get(group_name) + if group is None: + raise ValueError(f'Node group {group_name} does not exist. Please check the group name ' + 'or add a new one through `create_node_group`.') + if group['parent'] == p_name: + nodes = self._store.get_nodes(doc_ids=ids, group=group_name, kb_id=kb_id) + nodes = [ + n.copy( + global_metadata={ + RAG_KB_ID: target_kb_id, RAG_DOC_ID: doc_id_map[n.global_metadata[RAG_DOC_ID]][0] + }, + metadata=doc_id_map[n.global_metadata[RAG_DOC_ID]][1] + ) for n in nodes + ] + uid_map = {} + for n in nodes: + uid_map[n._copy_source.get('uid')] = n.uid + n.parent = p_uid_map.get(n.parent, None) if n.parent else None + self._store.update_nodes(nodes, copy=True) + if nodes: + self._copy_segments_recursive(ids=ids, kb_id=kb_id, target_kb_id=target_kb_id, + doc_id_map=doc_id_map, p_uid_map=uid_map, + p_name=group_name) + def _create_nodes_impl(self, p_nodes, group_name): # NOTE transform.batch_forward will set children for p_nodes, but when calling # transform.batch_forward, p_nodes has been upsert in the store. 
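For readers following the new transfer path end to end, a minimal sketch of how it is driven (not part of
the patch itself; processor stands for an already constructed _Processor, and the file paths, ids and kb
names are made-up placeholders):

    # Copy two already-parsed documents into another knowledge base. In transfer mode the
    # root nodes are read from the store rather than re-parsed from the files, and the
    # child node groups are then copied level by level via _copy_segments_recursive.
    processor.add_doc(input_files=['a.pdf', 'b.pdf'],
                      ids=['doc-a', 'doc-b'],
                      kb_id='kb_src',
                      transfer_mode='cp',            # only 'cp' and 'mv' are accepted
                      target_kb_id='kb_dst',
                      target_doc_ids=['doc-a-copy', 'doc-b-copy'])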
diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index ce05f2a0a..f35af76ac 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -15,7 +15,7 @@ from .base import ( ALGORITHM_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, FINISHED_TASK_QUEUE_TABLE_INFO, TaskStatus, TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, - _calculate_task_score + _calculate_task_score, _resolve_add_doc_task_type ) from .worker import DocumentProcessorWorker as Worker from .queue import _SQLBasedQueue as Queue @@ -311,65 +311,15 @@ def add_doc(self, request: AddDocRequest): # noqa: C901 if algorithm is None: raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') # NOTE: No idempotency key check, should be handled by the caller! - new_file_ids = [] - reparse_file_ids = [] - - transfer_mode = None - target_algo_id = None - target_kb_id = None - for file_info in file_infos: parse_file_path = file_info.transformed_file_path if \ file_info.transformed_file_path else file_info.file_path file_info.file_path = create_file_path(parse_file_path, prefix=self._path_prefix) - if file_info.reparse_group is not None: - reparse_file_ids.append(file_info.doc_id) - else: - new_file_ids.append(file_info.doc_id) - if file_info.transfer_params: - if target_algo_id is not None and target_algo_id != file_info.transfer_params.target_algo_id: - raise fastapi.HTTPException( - status_code=400, - detail='transfer_params.target_algo_id must be the same for all files' - ) - if target_kb_id is not None and target_kb_id != file_info.transfer_params.target_kb_id: - raise fastapi.HTTPException( - status_code=400, - detail='transfer_params.target_kb_id must be the same for all files' - ) - if transfer_mode is not None and transfer_mode != file_info.transfer_params.mode: - raise fastapi.HTTPException( - status_code=400, - detail='transfer_params.mode must be the same for all files' - ) - # NOTE: currently we only support file transfer in the same algorithm - if file_info.transfer_params.target_algo_id != algo_id: - raise fastapi.HTTPException( - status_code=400, - detail='transfer_params.target_algo_id must be the same for all files' - ) - target_algo_id = file_info.transfer_params.target_algo_id - target_kb_id = file_info.transfer_params.target_kb_id - transfer_mode = file_info.transfer_params.mode - if transfer_mode not in ['cp', 'mv']: - raise fastapi.HTTPException( - status_code=400, - detail='transfer_params.mode must be one of [cp, mv]' - ) - - if new_file_ids and reparse_file_ids: - raise fastapi.HTTPException( - status_code=400, - detail='new_file_ids and reparse_file_ids cannot be specified at the same time' - ) - if transfer_mode: - task_type = TaskType.DOC_TRANSFER.value - elif new_file_ids: - task_type = TaskType.DOC_ADD.value - elif reparse_file_ids: - task_type = TaskType.DOC_REPARSE.value - else: - raise fastapi.HTTPException(status_code=400, detail='no input files or reparse group specified') + + try: + task_type = _resolve_add_doc_task_type(request) + except ValueError as e: + raise fastapi.HTTPException(status_code=400, detail=str(e)) payload = request.model_dump() LOG.info(f'[DocumentProcessor] Received add doc request: {payload}') payload_json = json.dumps(payload, ensure_ascii=False) diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index e2d887101..e4619b07e 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py 
+++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -13,7 +13,7 @@ from .base import ( FINISHED_TASK_QUEUE_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, TaskStatus, TaskType, ALGORITHM_TABLE_INFO, AddDocRequest, UpdateMetaRequest, - DeleteDocRequest, _calculate_task_score + DeleteDocRequest, _calculate_task_score, _resolve_add_doc_task_type ) from .impl import _Processor from .queue import _SQLBasedQueue as Queue @@ -295,42 +295,8 @@ def _exec_update_meta_task(self, processor: _Processor, task_id: str, payload: d LOG.error(f'{self._log_prefix(task_id)} Execute update meta task failed: {e}') raise e - def _resolve_task_type(self, request: AddDocRequest) -> str: # noqa C901 - new_file_ids = [] - reparse_file_ids = [] - transfer_mode = None - target_algo_id = None - target_kb_id = None - - for file_info in request.file_infos: - if file_info.reparse_group is not None: - reparse_file_ids.append(file_info.doc_id) - else: - new_file_ids.append(file_info.doc_id) - if file_info.transfer_params: - if target_algo_id is not None and target_algo_id != file_info.transfer_params.target_algo_id: - raise ValueError('transfer_params.target_algo_id must be the same for all files') - if target_kb_id is not None and target_kb_id != file_info.transfer_params.target_kb_id: - raise ValueError('transfer_params.target_kb_id must be the same for all files') - if transfer_mode is not None and transfer_mode != file_info.transfer_params.mode: - raise ValueError('transfer_params.mode must be the same for all files') - if file_info.transfer_params.target_algo_id != request.algo_id: - raise ValueError('transfer_params.target_algo_id must be the same for all files') - target_algo_id = file_info.transfer_params.target_algo_id - target_kb_id = file_info.transfer_params.target_kb_id - transfer_mode = file_info.transfer_params.mode - if transfer_mode not in ['cp', 'mv']: - raise ValueError('transfer_params.mode must be one of [cp, mv]') - - if new_file_ids and reparse_file_ids: - raise ValueError('new_file_ids and reparse_file_ids cannot be specified at the same time') - if transfer_mode: - return TaskType.DOC_TRANSFER.value - if new_file_ids: - return TaskType.DOC_ADD.value - if reparse_file_ids: - return TaskType.DOC_REPARSE.value - raise ValueError('no input files or reparse group specified') + def _resolve_task_type(self, request: AddDocRequest) -> str: + return _resolve_add_doc_task_type(request) def _validate_task_payload(self, task_type: str, payload: dict): if not isinstance(payload, dict): From f82f0dd5dbf2b0f6c82de7e0cdfbbfb6171ef1bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Fri, 30 Jan 2026 16:14:29 +0800 Subject: [PATCH 04/46] modify code --- lazyllm/docs/tools.py | 6 + lazyllm/tools/rag/parsing_service/worker.py | 178 ++++++++++++++------ 2 files changed, 130 insertions(+), 54 deletions(-) diff --git a/lazyllm/docs/tools.py b/lazyllm/docs/tools.py index 075a365a2..e1f96487f 100644 --- a/lazyllm/docs/tools.py +++ b/lazyllm/docs/tools.py @@ -4025,6 +4025,9 @@ def post_func(task_id: str, task_status: str = None, error_code: str = None, err num_workers (int): 工作线程数,默认为1, 当大于1时,内部基于ray集群启动多个工作线程,否则仅启动一个工作线程。 port (Optional[int]): 服务端口号。默认为None,当为None时,将自动分配端口。 task_poller (Optional[Callable]): 外部任务拉取函数,可选。 + poll_mode (str): 任务拉取模式,可选值为 "direct" 或 "thread"。 + - "direct": 不启动独立 poller 线程,worker 空闲时直接拉取并立即处理任务(默认)。 + - "thread": 启动独立 poller 线程,持续拉取任务并入队。 lease_duration (float): 任务租约时长(秒),默认为300。 lease_renew_interval (float): 租约续租间隔(秒),默认为60。 high_priority_task_types (Optional[List[str]]): 
高优任务类型列表,可选。 @@ -4040,6 +4043,9 @@ def post_func(task_id: str, task_status: str = None, error_code: str = None, err num_workers (int): Number of worker threads, defaults to 1, when it is greater than 1, multiple worker threads are started internally based on the ray cluster, otherwise only one worker thread is started. port (Optional[int]): Service port number. Defaults to None, when it is None, a random port will be assigned. task_poller (Optional[Callable]): External task poller callback, optional. + poll_mode (str): Task polling mode, either "direct" or "thread". + - "direct": No dedicated poller thread; the worker pulls and processes tasks when idle (default). + - "thread": Run a dedicated poller thread to continuously fetch tasks and enqueue them. lease_duration (float): Task lease duration in seconds, defaults to 300. lease_renew_interval (float): Lease renewal interval in seconds, defaults to 60. high_priority_task_types (Optional[List[str]]): High priority task types, optional. diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index e4619b07e..f4795bffa 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -27,7 +27,7 @@ class DocumentProcessorWorker(ModuleBase): class _Impl(): def __init__(self, db_config: dict = None, task_poller=None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: list[str] = None, - high_priority_only: bool = False): + high_priority_only: bool = False, poll_mode: str = 'direct'): self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') self._shutdown = False self._processors: dict[str, _Processor] = {} # algo_id -> _Processor @@ -39,6 +39,9 @@ def __init__(self, db_config: dict = None, task_poller=None, lease_duration: flo raise TypeError('task_poller is not callable') self._task_poller = task_poller self._task_poller_impl = self._wrap_task_poller(task_poller) if task_poller else None + if poll_mode not in ('direct', 'thread'): + raise ValueError('poll_mode must be one of ["direct", "thread"]') + self._poll_mode = poll_mode self._worker_id = f'{self._get_worker_identity()}-{uuid4()}' self._in_progress_task = None self._lease_thread = None @@ -352,6 +355,71 @@ def _enqueue_task_from_payload(self, task: dict): LOG.warning(f'{self._log_prefix()} [Poller] Skip invalid task payload: {e}. 
' f'payload={task}') + def _parse_task_payload(self, task: dict): + task_type = task.get('task_type') + if task_type == TaskType.DOC_DELETE.value: + task_info = DeleteDocRequest(**task) + elif task_type == TaskType.DOC_UPDATE_META.value: + task_info = UpdateMetaRequest(**task) + else: + task_info = AddDocRequest(**task) + task_type = task_type or self._resolve_task_type(task_info) + task_id = task_info.task_id + payload = task_info.model_dump() + self._validate_task_payload(task_type, payload) + return task_id, task_type, payload + + def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: bool): + try: + self._in_progress_task = {'task_id': task_id, 'task_type': task_type} + if from_queue: + self._start_lease_renewal(task_id) + algo_id = payload.get('algo_id') + if not algo_id: + raise ValueError(f'{self._log_prefix(task_id)} task_id is missing algo_id in payload: {payload}') + + LOG.info(f'{self._log_prefix(task_id)} Start processing task, type: {task_type}, ' + f'algo_id: {algo_id}') + + processor = self._get_or_create_processor(algo_id) + if task_type == TaskType.DOC_ADD.value: + self._exec_add_task(processor, task_id, payload) + elif task_type == TaskType.DOC_REPARSE.value: + self._exec_reparse_task(processor, task_id, payload) + elif task_type == TaskType.DOC_DELETE.value: + self._exec_delete_task(processor, task_id, payload) + elif task_type == TaskType.DOC_UPDATE_META.value: + self._exec_update_meta_task(processor, task_id, payload) + elif task_type == TaskType.DOC_TRANSFER.value: + self._exec_transfer_task(processor, task_id, payload) + else: + raise ValueError(f'{self._log_prefix(task_id)} Unknown task type: {task_type}') + + self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FINISHED, + error_code='200', error_msg='success') + if from_queue: + deleted = self._waiting_task_queue.delete( + filter_by={'task_id': task_id, 'worker_id': self._worker_id} + ) + if deleted == 0: + LOG.warning(f'{self._log_prefix(task_id)} Failed to delete finished task') + except Exception as e: + LOG.error(f'{self._log_prefix(task_id)} Failed to run task: {e}, {traceback.format_exc()}') + if task_id and task_type: + self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FAILED, + error_code=type(e).__name__, error_msg=str(e)) + if from_queue: + deleted = self._waiting_task_queue.delete( + filter_by={'task_id': task_id, 'worker_id': self._worker_id} + ) + if deleted == 0: + LOG.warning(f'{self._log_prefix(task_id)} Failed to delete failed task') + time.sleep(WORKER_ERROR_RETRY_INTERVAL) + finally: + if from_queue: + self._stop_lease_renewal() + self._in_progress_task = None + def _poller(self): # noqa: C901 while not self._shutdown: try: @@ -401,62 +469,62 @@ def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: Task def _worker_impl(self): # noqa: C901 while not self._shutdown: - task_id = None - task_type = None try: task_data = self._poll_task() - if not task_data: - time.sleep(0.1) - continue - - task_id = task_data['task_id'] - task_type = task_data['task_type'] - self._in_progress_task = {'task_id': task_id, 'task_type': task_type} - self._start_lease_renewal(task_id) - payload = json.loads(task_data.get('message')) - algo_id = payload.get('algo_id') - if not algo_id: - raise ValueError(f'{self._log_prefix(task_id)} task_id is missing algo_id in payload: {payload}') - - LOG.info(f'{self._log_prefix(task_id)} Start processing task, type: {task_type}, ' - f'algo_id: {algo_id}') - - 
processor = self._get_or_create_processor(algo_id) - if task_type == TaskType.DOC_ADD.value: - self._exec_add_task(processor, task_id, payload) - elif task_type == TaskType.DOC_REPARSE.value: - self._exec_reparse_task(processor, task_id, payload) - elif task_type == TaskType.DOC_DELETE.value: - self._exec_delete_task(processor, task_id, payload) - elif task_type == TaskType.DOC_UPDATE_META.value: - self._exec_update_meta_task(processor, task_id, payload) - elif task_type == TaskType.DOC_TRANSFER.value: - self._exec_transfer_task(processor, task_id, payload) - else: - raise ValueError(f'{self._log_prefix(task_id)} Unknown task type: {task_type}') - - self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FINISHED, - error_code='200', error_msg='success') - deleted = self._waiting_task_queue.delete( - filter_by={'task_id': task_id, 'worker_id': self._worker_id} - ) - if deleted == 0: - LOG.warning(f'{self._log_prefix(task_id)} Failed to delete finished task') except Exception as e: - LOG.error(f'{self._log_prefix(task_id)} Failed to run task: {e}, {traceback.format_exc()}') - if task_id and task_type: - self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FAILED, - error_code=type(e).__name__, error_msg=str(e)) - deleted = self._waiting_task_queue.delete( - filter_by={'task_id': task_id, 'worker_id': self._worker_id} - ) - if deleted == 0: - LOG.warning(f'{self._log_prefix(task_id)} Failed to delete failed task') + LOG.error(f'{self._log_prefix()} [Worker] poll_task failed: {e}, {traceback.format_exc()}') time.sleep(WORKER_ERROR_RETRY_INTERVAL) continue - finally: - self._stop_lease_renewal() - self._in_progress_task = None + if task_data: + try: + payload = json.loads(task_data.get('message')) + except Exception as e: + task_id = task_data.get('task_id') + task_type = task_data.get('task_type') + LOG.error(f'{self._log_prefix(task_id)} [Worker] Failed to parse task payload: {e}, ' + f'{traceback.format_exc()}') + if task_id and task_type: + try: + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.FAILED, + error_code=type(e).__name__, + error_msg=str(e), + ) + deleted = self._waiting_task_queue.delete( + filter_by={'task_id': task_id, 'worker_id': self._worker_id} + ) + if deleted == 0: + LOG.warning(f'{self._log_prefix(task_id)} Failed to delete invalid task') + except Exception as inner_e: + LOG.error(f'{self._log_prefix(task_id)} Failed to cleanup invalid task: {inner_e}, ' + f'{traceback.format_exc()}') + time.sleep(WORKER_ERROR_RETRY_INTERVAL) + continue + self._run_task(task_data['task_id'], task_data['task_type'], payload, from_queue=True) + continue + + if self._task_poller_impl is not None and self._poll_mode == 'direct': + try: + tasks = self._task_poller_impl() + if not tasks: + time.sleep(0.1) + continue + for task in tasks: + try: + task_id, task_type, payload = self._parse_task_payload(task) + except Exception as e: + LOG.warning(f'{self._log_prefix()} [Poller] Skip invalid task payload: {e}. 
' + f'payload={task}') + continue + self._run_task(task_id, task_type, payload, from_queue=False) + except Exception as e: + LOG.error(f'{self._log_prefix()} [Poller] fetch failed: {e}') + time.sleep(WORKER_ERROR_RETRY_INTERVAL) + continue + + time.sleep(0.1) def start(self): LOG.info(f'{self._log_prefix()} Starting worker...') @@ -465,7 +533,7 @@ def start(self): LOG.warning(f'{self._log_prefix()} Worker thread is already running') return self._shutdown = False - if self._task_poller_impl is not None: + if self._task_poller_impl is not None and self._poll_mode == 'thread': if self._poller_thread is None or not self._poller_thread.is_alive(): self._poller_thread = threading.Thread(target=self._poller, daemon=True) self._poller_thread.start() @@ -492,7 +560,8 @@ def shutdown(self): def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = None, task_poller=None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, - high_priority_task_types: list[str] = None, high_priority_only: bool = False): + high_priority_task_types: list[str] = None, high_priority_only: bool = False, + poll_mode: str = 'direct'): super().__init__() self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') self._num_workers = num_workers @@ -504,6 +573,7 @@ def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = Non lease_renew_interval=lease_renew_interval, high_priority_task_types=high_priority_task_types, high_priority_only=high_priority_only, + poll_mode=poll_mode, ) self._worker_impl = ServerModule(worker_impl, port=self._port, num_replicas=self._num_workers) LOG.info(f'[DocumentProcessorWorker] Worker initialized with {num_workers} workers') From e772a42340540e98d2f9f168fd07f71d876192cb Mon Sep 17 00:00:00 2001 From: wangzhihong Date: Wed, 4 Feb 2026 16:02:14 +0800 Subject: [PATCH 05/46] fix schema_extractor bug --- lazyllm/tools/rag/__init__.py | 3 +- lazyllm/tools/rag/doc_to_db/extractor.py | 8 ++--- lazyllm/tools/rag/document.py | 36 ++++++++++++----------- lazyllm/tools/rag/parsing_service/impl.py | 10 +++---- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index ec4489dab..fa1f59bab 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -11,7 +11,7 @@ CharacterSplitter, RecursiveSplitter, MarkdownSplitter, CodeSplitter, JSONSplitter, YAMLSplitter, HTMLSplitter, XMLSplitter, GeneralCodeSplitter, JSONLSplitter) from .similarity import register_similarity -from .doc_node import DocNode +from .doc_node import DocNode, RichDocNode from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader, MineruPDFReader) @@ -46,6 +46,7 @@ 'register_similarity', 'register_reranker', 'DocNode', + 'RichDocNode', 'PDFReader', 'DocxReader', 'HWPReader', diff --git a/lazyllm/tools/rag/doc_to_db/extractor.py b/lazyllm/tools/rag/doc_to_db/extractor.py index 089692250..8f672d486 100644 --- a/lazyllm/tools/rag/doc_to_db/extractor.py +++ b/lazyllm/tools/rag/doc_to_db/extractor.py @@ -13,7 +13,7 @@ from lazyllm import LOG, ThreadPoolExecutor, once_wrapper from lazyllm.components import JsonFormatter -from lazyllm.module import LLMBase +from lazyllm.module import LLMBase, ModuleBase from ...sql.sql_manager import DBStatus, SqlManager from ..doc_node import DocNode @@ -33,7 +33,7 @@ ONE_DOC_LENGTH_LIMIT = 102400 -class 
SchemaExtractor: +class SchemaExtractor(ModuleBase): '''Schema aware extractor that materializes BaseModel schemas into database tables.''' TABLE_PREFIX = 'lazyllm_schema' @@ -724,8 +724,8 @@ def _get_extract_data(self, algo_id: str, doc_ids: List[str], # noqa: C901 results.append(ExtractResult(data=row_data, metadata=meta)) return results - def __call__(self, data: Union[str, List[DocNode]], - algo_id: str = DocListManager.DEFAULT_GROUP_NAME) -> ExtractResult: + def forward(self, data: Union[str, List[DocNode]], + algo_id: str = DocListManager.DEFAULT_GROUP_NAME) -> ExtractResult: # NOTE: data should be from single file source (kb_id, doc_id should be the same) self._lazy_init() res = self.extract_and_store(data=data, algo_id=algo_id) diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 9102400bf..36119fbbe 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -58,6 +58,7 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, self._dataset_path = dataset_path self._embed = self._get_embeds(embed) self._processor = processor + self._schema_extractor = self._register_submodules(schema_extractor) name = name or DocListManager.DEFAULT_GROUP_NAME if not display_name: display_name = name @@ -90,21 +91,22 @@ def web_url(self): def _get_embeds(self, embed): embeds = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {} - for embed in embeds.values(): - if isinstance(embed, ModuleBase): - self._submodules.append(embed) - return embeds - - def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, - store_conf: Optional[Dict] = None, - embed: Optional[Union[Callable, Dict[str, Callable]]] = None): + return self._register_submodules(embeds) + + def _register_submodules(self, m): + if not m: return m + for embed in (m.values() if isinstance(m, dict) else m if isinstance(m, (tuple, list)) else [m]): + if isinstance(embed, ModuleBase): self._submodules.append(embed) + return m + + def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, store_conf: Optional[Dict] = None, + embed: Optional[Union[Callable, Dict[str, Callable]]] = None, + schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): embed = self._get_embeds(embed) if embed else self._embed - if isinstance(self._kbs, ServerModule): - self._kbs._impl._m[name] = DocImpl(dlm=self._dlm, embed=embed, kb_group_name=name, - global_metadata_desc=doc_fields, store=store_conf) - else: - self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, - global_metadata_desc=doc_fields, store=store_conf) + schema_extractor = self._register_submodules(schema_extractor) or self._schema_extractor + impl = DocImpl(dlm=self._dlm, embed=embed, kb_group_name=name, global_metadata_desc=doc_fields, + store=store_conf, schema_extractor=schema_extractor) + (self._kbs._impl._m if isinstance(self._kbs, ServerModule) else self._kbs)[name] = impl self._dlm.add_kb_group(name=name) def get_doc_by_kb_group(self, name): @@ -147,7 +149,6 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal 'Only map store is supported for Document with temp-files') name = name or DocListManager.DEFAULT_GROUP_NAME - self._schema_extractor: SchemaExtractor = schema_extractor if isinstance(manager, Document._Manager): assert not server, 'Server infomation is already set to by manager' @@ -157,7 +158,8 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal if 
dataset_path != manager._dataset_path and dataset_path != manager._origin_path: raise RuntimeError(f'Document path mismatch, expected `{manager._dataset_path}`' f'while received `{dataset_path}`') - manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed) + manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed, + schema_extractor=schema_extractor) self._manager = manager self._curr_group = name else: @@ -173,7 +175,7 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal self._manager = Document._Manager(dataset_path, embed, manager, server, name, launcher, store_conf, doc_fields, cloud=cloud, doc_files=doc_files, processor=processor, display_name=display_name, description=description, - schema_extractor=self._schema_extractor) + schema_extractor=schema_extractor) self._curr_group = name self._doc_to_db_processor: DocToDbProcessor = None self._graph_document: weakref.ref = None diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index 9b9c7b13a..b6dd4c0d8 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -5,6 +5,7 @@ from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor from functools import cached_property +from itertools import repeat from lazyllm import LOG @@ -95,12 +96,9 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no try: if not input_files: return if not ids: ids = [gen_docid(path) for path in input_files] - if metadatas is None: - metadatas = [{} for _ in input_files] - for metadata, doc_id, path in zip(metadatas, ids, input_files): - metadata.setdefault(RAG_DOC_ID, doc_id) - metadata.setdefault(RAG_DOC_PATH, path) - metadata.setdefault(RAG_KB_ID, kb_id or DEFAULT_KB_ID) + temp_metas = [{RAG_DOC_ID: doc_id, RAG_DOC_PATH: path, RAG_KB_ID: kb_id or DEFAULT_KB_ID} + for doc_id, path in zip(ids, input_files)] + metadatas = [{**temp, **(metadata)} for metadata, temp in zip(metadatas or repeat({}), temp_metas)] kb_id = metadatas[0].get(RAG_KB_ID, DEFAULT_KB_ID) if kb_id is None else kb_id root_nodes = self._reader.load_data(input_files, metadatas, split_nodes_by_type=True) schema_futures = [] From 9c714de6a32f7794f07f612fbd982de78eda16d7 Mon Sep 17 00:00:00 2001 From: wangzhihong Date: Fri, 6 Feb 2026 21:50:53 +0800 Subject: [PATCH 06/46] temp --- lazyllm/tools/rag/doc_impl.py | 26 ++++++-------------------- tests/basic_tests/RAG/test_document.py | 2 +- 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 2370329ec..0792a070f 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -192,11 +192,9 @@ def _lazy_init(self) -> None: self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, new_status=DocListManager.Status.success) if self._dlm: - self._init_monitor_event = threading.Event() self._daemon = threading.Thread(target=self.worker) self._daemon.daemon = True self._daemon.start() - self._init_monitor_event.wait() def _resolve_index_pending_registrations(self): for index_type, index_cls, index_args, index_kwargs in self._index_pending_registrations: @@ -298,13 +296,11 @@ def add_reader(self, pattern: str, func: Optional[Callable] = None): self._local_file_reader[pattern] = func self._reader._lazy_init.flag.reset() - def _add_doc_to_store_with_status(self, input_files: List[str], ids: 
List[str], metadatas: List[Dict[str, Any]], - cond_status_list: Optional[List[str]] = None): + def _add_doc_to_store(self, input_files: List[str], ids: List[str], metadatas: List[Dict[str, Any]]): success_ids, failed_ids = [], [] for filepath, doc_id, metadata in zip(input_files, ids or repeat(None), metadatas or repeat(None)): try: - self._add_doc_to_store(input_files=[filepath], ids=[doc_id] if doc_id is not None else None, - metadatas=[metadata] if metadata is not None else None) + self._processor.add_doc([filepath], [doc_id], [metadata] if metadata is not None else None) success_ids.append(doc_id) except Exception as e: LOG.error(f'Error adding document {doc_id} ({filepath}) to store: {e}') @@ -312,10 +308,10 @@ def _add_doc_to_store_with_status(self, input_files: List[str], ids: List[str], if success_ids: self._dlm.update_kb_group(cond_file_ids=success_ids, cond_group=self._kb_group_name, - cond_status_list=cond_status_list, new_status=DocListManager.Status.success) + new_status=DocListManager.Status.success) if failed_ids: self._dlm.update_kb_group(cond_file_ids=failed_ids, cond_group=self._kb_group_name, - cond_status_list=cond_status_list, new_status=DocListManager.Status.failed) + new_status=DocListManager.Status.failed) def _batch_call(self, func: Callable, *args, batch_size: int = 10, **kwargs): batch_count = next((len(arg) for arg in args if isinstance(arg, (tuple, list))), 0) @@ -323,7 +319,6 @@ def _batch_call(self, func: Callable, *args, batch_size: int = 10, **kwargs): func(*[arg[i:i + batch_size] if isinstance(arg, (list, tuple)) else arg for arg in args], **kwargs) def worker(self): - is_first_run = True while True: # Apply meta changes rows = self._dlm.fetch_docs_changed_meta(self._kb_group_name) @@ -341,7 +336,7 @@ def worker(self): self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, new_status=DocListManager.Status.working, new_need_reparse=False) self._delete_doc_from_store(doc_ids=ids) - self._batch_call(self._add_doc_to_store_with_status, filepaths, ids, metadatas, batch_size=10) + self._batch_call(self._add_doc_to_store, filepaths, ids, metadatas, batch_size=10) # Step 2: After doc is deleted from related kb_group, delete doc from db if self._kb_group_name == DocListManager.DEFAULT_GROUP_NAME: @@ -359,12 +354,8 @@ def worker(self): if files: self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, new_status=DocListManager.Status.working) - self._batch_call(self._add_doc_to_store_with_status, - files, ids, metadatas, cond_status_list=[DocListManager.Status.working]) + self._batch_call(self._add_doc_to_store, files, ids, metadatas) - if is_first_run: - self._init_monitor_event.set() - is_first_run = False time.sleep(10) def _list_files( @@ -381,11 +372,6 @@ def _list_files( metadatas.append(json.loads(row[3]) if row[3] else {}) return ids, paths, metadatas - def _add_doc_to_store(self, input_files: List[str], ids: Optional[List[str]] = None, - metadatas: Optional[List[Dict[str, Any]]] = None): - if not input_files: return - self._processor.add_doc(input_files, ids, metadatas) - def _delete_doc_from_store(self, doc_ids: List[str] = None) -> None: self._processor.delete_doc(doc_ids=doc_ids) diff --git a/tests/basic_tests/RAG/test_document.py b/tests/basic_tests/RAG/test_document.py index 649b74d68..5db8f2319 100644 --- a/tests/basic_tests/RAG/test_document.py +++ b/tests/basic_tests/RAG/test_document.py @@ -85,7 +85,7 @@ def test_add_files(self): new_doc = DocNode(text='new dummy text', group=LAZY_ROOT_NAME) 
new_doc._global_metadata = {RAG_DOC_ID: gen_docid(self.tmp_file_b.name), RAG_DOC_PATH: self.tmp_file_b.name} self.mock_directory_reader.load_data.return_value = {LAZY_ROOT_NAME: [new_doc], LAZY_IMAGE_GROUP: []} - self.doc_impl._add_doc_to_store([self.tmp_file_b.name]) + self.doc_impl._processor._add_doc([self.tmp_file_b.name]) assert len(self.doc_impl.store.get_nodes(group=LAZY_ROOT_NAME)) == 2 class TestDocument(unittest.TestCase): From d40ad667c810a41698da3b0e5658cc5cdfe15bed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Sat, 7 Feb 2026 11:02:13 +0800 Subject: [PATCH 07/46] modify --- .../tools/rag/store/hybrid/sensecore_store.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/lazyllm/tools/rag/store/hybrid/sensecore_store.py b/lazyllm/tools/rag/store/hybrid/sensecore_store.py index 337b89c28..33cde7b7e 100644 --- a/lazyllm/tools/rag/store/hybrid/sensecore_store.py +++ b/lazyllm/tools/rag/store/hybrid/sensecore_store.py @@ -51,6 +51,12 @@ def __init__(self, uri: str = '', **kwargs): self._s3_config = kwargs.get('s3_config') self._image_url_config = kwargs.get('image_url_config') self._uploaded_image_keys = set() + self._path_prefix = kwargs.get('path_prefix') + if not self._path_prefix: + try: + self._path_prefix = config['image_path_prefix'] + except Exception: + self._path_prefix = os.getenv('RAG_IMAGE_PATH_PREFIX', '') @property def dir(self): @@ -99,11 +105,7 @@ def _serialize_data(self, data: dict) -> Dict: # noqa: C901 continue image_file_name = os.path.basename(image_path) obj_key = f'lazyllm/images/{kb_id}/{doc_id}/{image_file_name}' - try: - prefix = config['image_path_prefix'] - except Exception: - prefix = os.getenv('RAG_IMAGE_PATH_PREFIX', '') - file_path = create_file_path(path=image_path, prefix=prefix) + file_path = create_file_path(path=image_path, prefix=self._path_prefix) try: self._upload_image_if_needed(file_path, obj_key) content = content.replace(image_path, obj_key) @@ -124,11 +126,7 @@ def _serialize_data(self, data: dict) -> Dict: # noqa: C901 continue image_name = os.path.basename(image_path) obj_key = f'lazyllm/images/{kb_id}/{doc_id}/{image_name}' - try: - prefix = config['image_path_prefix'] - except Exception: - prefix = os.getenv('RAG_IMAGE_PATH_PREFIX', '') - file_path = create_file_path(path=image_path, prefix=prefix) + file_path = create_file_path(path=image_path, prefix=self._path_prefix) try: self._upload_image_if_needed(file_path, obj_key) md_info = md_info.replace(image_path, obj_key) @@ -188,11 +186,7 @@ def _serialize_data(self, data: dict) -> Dict: # noqa: C901 continue image_file_name = os.path.basename(image_path) obj_key = f'lazyllm/images/{kb_id}/{doc_id}/{image_file_name}' - try: - prefix = config['image_path_prefix'] - except Exception: - prefix = os.getenv('RAG_IMAGE_PATH_PREFIX', '') - file_path = create_file_path(path=image_path, prefix=prefix) + file_path = create_file_path(path=image_path, prefix=self._path_prefix) try: self._upload_image_if_needed(file_path, obj_key) answer = answer.replace(image_path, obj_key) From 2e64bda849222300e0db09a1e8950054d1bf4b4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 10 Mar 2026 09:34:22 +0800 Subject: [PATCH 08/46] temp code for mock api --- examples/rag/doc_service_mock_example.py | 128 ++ examples/rag/doc_service_standalone.py | 67 ++ lazyllm/tools/rag/__init__.py | 1 + lazyllm/tools/rag/doc_service/__init__.py | 5 + lazyllm/tools/rag/doc_service/base.py | 286 +++++ 
lazyllm/tools/rag/doc_service/doc_manager.py | 1042 +++++++++++++++++ lazyllm/tools/rag/doc_service/doc_server.py | 577 +++++++++ .../tools/rag/doc_service/parsing_server.py | 434 +++++++ lazyllm/tools/rag/document.py | 17 +- lazyllm/tools/rag/parsing_service/base.py | 2 + lazyllm/tools/rag/parsing_service/server.py | 6 + lazyllm/tools/rag/parsing_service/worker.py | 6 + lazyllm/tools/rag/utils.py | 10 +- 13 files changed, 2574 insertions(+), 7 deletions(-) create mode 100644 examples/rag/doc_service_mock_example.py create mode 100644 examples/rag/doc_service_standalone.py create mode 100644 lazyllm/tools/rag/doc_service/__init__.py create mode 100644 lazyllm/tools/rag/doc_service/base.py create mode 100644 lazyllm/tools/rag/doc_service/doc_manager.py create mode 100644 lazyllm/tools/rag/doc_service/doc_server.py create mode 100644 lazyllm/tools/rag/doc_service/parsing_server.py diff --git a/examples/rag/doc_service_mock_example.py b/examples/rag/doc_service_mock_example.py new file mode 100644 index 000000000..0fdfd1e53 --- /dev/null +++ b/examples/rag/doc_service_mock_example.py @@ -0,0 +1,128 @@ +'''DocService mock quickstart. + +Run: + python examples/rag/doc_service_mock_example.py +''' + +from __future__ import annotations + +import argparse +import io +import os +import tempfile +import time + +import requests + +from lazyllm import Document + + +def _wait_task(base_url: str, task_id: str, targets: set[str], timeout: float = 10.0): + start = time.time() + while time.time() - start < timeout: + resp = requests.get(f'{base_url}/v1/tasks/{task_id}', timeout=5) + resp.raise_for_status() + task = resp.json()['data'] + if task['status'] in targets: + return task + time.sleep(0.2) + raise TimeoutError(f'task {task_id} did not reach {targets}') + + +def main(): + parser = argparse.ArgumentParser(description='DocService mock quickstart.') + parser.add_argument('--wait', action='store_true', help='Keep server alive for manual API/docs inspection.') + parser.add_argument('--doc-server-port', type=int, default=None, help='DocServer listen port.') + args = parser.parse_args() + + with tempfile.TemporaryDirectory(prefix='lazyllm_doc_service_demo_') as tmp: + storage = os.path.join(tmp, 'uploads') + os.makedirs(storage, exist_ok=True) + seed_path = os.path.join(storage, 'seed.txt') + with open(seed_path, 'w', encoding='utf-8') as f: + f.write('seed content') + + doc = Document( + dataset_path=storage, + manager=True, + name='demo_doc_service', + doc_server_port=args.doc_server_port, + ) + doc.start() + + try: + base_url = doc.manager.url.rsplit('/', 1)[0] + print(f'DocService URL: {base_url}') + print(f'Swagger Docs: {base_url}/docs') + + upload_resp = requests.post( + f'{base_url}/v1/docs/upload', + params={'kb_id': 'kb_demo', 'algo_id': '__default__'}, + files=[('files', ('demo.txt', io.BytesIO(b'hello lazyllm rag'), 'text/plain'))], + timeout=10, + ) + upload_resp.raise_for_status() + upload_item = upload_resp.json()['data']['items'][0] + doc_id = upload_item['doc_id'] + task_id = upload_item['task_id'] + _wait_task(base_url, task_id, {'SUCCESS'}) + + patch_resp = requests.post( + f'{base_url}/v1/docs/metadata/patch', + json={ + 'kb_id': 'kb_demo', + 'algo_id': '__default__', + 'items': [{'doc_id': doc_id, 'patch': {'owner': 'demo_user', 'scene': 'quickstart'}}], + }, + timeout=10, + ) + patch_resp.raise_for_status() + patch_task = patch_resp.json()['data']['items'][0]['task_id'] + _wait_task(base_url, patch_task, {'SUCCESS'}) + + reparse_resp = requests.post( + f'{base_url}/v1/docs/reparse', + 
json={'kb_id': 'kb_demo', 'algo_id': '__default__', 'doc_ids': [doc_id]}, + timeout=10, + ) + reparse_resp.raise_for_status() + reparse_task = reparse_resp.json()['data']['task_ids'][0] + _wait_task(base_url, reparse_task, {'SUCCESS'}) + + add_resp = requests.post( + f'{base_url}/v1/docs/add', + json={'kb_id': 'kb_demo', 'algo_id': '__default__', 'items': [{'file_path': seed_path}]}, + timeout=10, + ) + add_resp.raise_for_status() + add_task = add_resp.json()['data']['items'][0]['task_id'] + _wait_task(base_url, add_task, {'SUCCESS'}) + + docs_resp = requests.get( + f'{base_url}/v1/docs', + params={'kb_id': 'kb_demo', 'include_deleted_or_canceled': False}, + timeout=10, + ) + docs_resp.raise_for_status() + docs = docs_resp.json()['data']['items'] + print(f'Current docs in kb_demo: {len(docs)}') + + delete_resp = requests.post( + f'{base_url}/v1/docs/delete', + json={'kb_id': 'kb_demo', 'algo_id': '__default__', 'doc_ids': [doc_id]}, + timeout=10, + ) + delete_resp.raise_for_status() + delete_task = delete_resp.json()['data']['items'][0]['task_id'] + _wait_task(base_url, delete_task, {'DELETED'}) + print('Doc lifecycle demo completed.') + if args.wait: + print('Server is running. Press Ctrl+C to stop...') + while True: + time.sleep(1) + finally: + doc.stop() + + +if __name__ == '__main__': + main() diff --git a/examples/rag/doc_service_standalone.py b/examples/rag/doc_service_standalone.py new file mode 100644 index 000000000..3e968f75c --- /dev/null +++ b/examples/rag/doc_service_standalone.py @@ -0,0 +1,67 @@ +'''Start standalone DocService mock server. + +Run: + python examples/rag/doc_service_standalone.py --wait +''' + +from __future__ import annotations + +import argparse +import os +import tempfile +import time + + +def main(): + parser = argparse.ArgumentParser(description='Standalone DocService mock server.') + parser.add_argument('--port', type=int, default=None, help='DocServer listen port.') + parser.add_argument('--wait', action='store_true', help='Keep server alive for manual API/docs inspection.') + args = parser.parse_args() + + from lazyllm.tools.rag.doc_service import DocServer + + tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_standalone_') + storage_dir = os.path.join(tmp_dir, 'uploads') + os.makedirs(storage_dir, exist_ok=True) + db_config = { + 'db_type': 'sqlite', + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'db_name': os.path.join(tmp_dir, 'doc_service.db'), + } + parser_db_config = { + 'db_type': 'sqlite', + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'db_name': os.path.join(tmp_dir, 'doc_service_parser.db'), + } + + server = DocServer( + storage_dir=storage_dir, + db_config=db_config, + parser_db_config=parser_db_config, + port=args.port, + ) + server.start() + base_url = server.url.rsplit('/', 1)[0] + print(f'DocService URL: {base_url}', flush=True) + print(f'Swagger Docs: {base_url}/docs', flush=True) + print(f'Storage Dir: {storage_dir}', flush=True) + print(f'Doc DB: {db_config["db_name"]}', flush=True) + print(f'Parser DB: {parser_db_config["db_name"]}', flush=True) + + try: + if args.wait: + print('Server is running. 
Press Ctrl+C to stop...', flush=True) + while True: + time.sleep(1) + finally: + server.stop() + + +if __name__ == '__main__': + main() diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index fa1f59bab..86834668d 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -1,4 +1,5 @@ from lazyllm.thirdparty import check_dependency_by_group + check_dependency_by_group('rag') # flake8: noqa: E402 diff --git a/lazyllm/tools/rag/doc_service/__init__.py b/lazyllm/tools/rag/doc_service/__init__.py new file mode 100644 index 000000000..94dc21ff8 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/__init__.py @@ -0,0 +1,5 @@ +from .doc_server import DocServer +from .doc_manager import DocManager +from .parsing_server import ParsingTaskServer + +__all__ = ['DocServer', 'DocManager', 'ParsingTaskServer'] diff --git a/lazyllm/tools/rag/doc_service/base.py b/lazyllm/tools/rag/doc_service/base.py new file mode 100644 index 000000000..a69486db8 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/base.py @@ -0,0 +1,286 @@ +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +from pydantic import BaseModel, Field +from ..parsing_service.base import TaskType + + +class DocStatus(str, Enum): + WAITING = 'WAITING' + WORKING = 'WORKING' + SUCCESS = 'SUCCESS' + FAILED = 'FAILED' + CANCELED = 'CANCELED' + DELETING = 'DELETING' + DELETED = 'DELETED' + + +class KBStatus(str, Enum): + ACTIVE = 'ACTIVE' + DELETING = 'DELETING' + DELETED = 'DELETED' + + +class SourceType(str, Enum): + API = 'API' + SCAN = 'SCAN' + TEMP = 'TEMP' + EXTERNAL = 'EXTERNAL' + + +class CallbackEventType(str, Enum): + START = 'START' + FINISH = 'FINISH' + + +BIZ_HTTP_CODE = { + 'E_INVALID_PARAM': 400, + 'E_NOT_FOUND': 404, + 'E_STATE_CONFLICT': 409, + 'E_IDEMPOTENCY_CONFLICT': 409, + 'E_IDEMPOTENCY_IN_PROGRESS': 409, +} + + +class DocServiceError(Exception): + def __init__(self, biz_code: str, msg: str, data: Optional[Dict[str, Any]] = None): + super().__init__(biz_code, msg, data) + self.biz_code = biz_code + self.msg = msg + self.data = data or {} + + @property + def http_status(self): + return BIZ_HTTP_CODE.get(self.biz_code, 500) + + +class TaskCreateRequest(BaseModel): + task_id: str = Field(default_factory=lambda: str(uuid4())) + task_type: TaskType + doc_id: str + kb_id: str = '__default__' + algo_id: str = '__default__' + metadata: Dict[str, Any] = Field(default_factory=dict) + priority: int = 0 + callback_url: Optional[str] = None + + +class TaskCallbackRequest(BaseModel): + callback_id: str = Field(default_factory=lambda: str(uuid4())) + task_id: str + event_type: CallbackEventType + status: DocStatus + error_code: Optional[str] = None + error_msg: Optional[str] = None + payload: Dict[str, Any] = Field(default_factory=dict) + + +class AddFileItem(BaseModel): + file_path: str + doc_id: Optional[str] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + + +class AddRequest(BaseModel): + items: List[AddFileItem] + kb_id: str = '__default__' + algo_id: str = '__default__' + source_type: SourceType = SourceType.EXTERNAL + idempotency_key: Optional[str] = None + + +class UploadRequest(BaseModel): + items: List[AddFileItem] + kb_id: str = '__default__' + algo_id: str = '__default__' + source_type: SourceType = SourceType.API + idempotency_key: Optional[str] = None + + +class ReparseRequest(BaseModel): + doc_ids: List[str] + kb_id: str = '__default__' + algo_id: str = 
'__default__' + idempotency_key: Optional[str] = None + + +class DeleteRequest(BaseModel): + doc_ids: List[str] + kb_id: str = '__default__' + algo_id: str = '__default__' + idempotency_key: Optional[str] = None + + +class TransferItem(BaseModel): + doc_id: str + source_kb_id: str = '__default__' + source_algo_id: str = '__default__' + target_kb_id: str + target_algo_id: str + mode: str = 'copy' + + +class TransferRequest(BaseModel): + items: List[TransferItem] + idempotency_key: Optional[str] = None + + +class MetadataPatchItem(BaseModel): + doc_id: str + patch: Dict[str, Any] = Field(default_factory=dict) + + +class MetadataPatchRequest(BaseModel): + items: List[MetadataPatchItem] + kb_id: str = '__default__' + algo_id: str = '__default__' + idempotency_key: Optional[str] = None + + +IDEMPOTENCY_RECORDS_TABLE_INFO = { + 'name': 'lazyllm_idempotency_records', + 'comment': 'Idempotency replay records', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'endpoint', 'data_type': 'string', 'nullable': False, 'comment': 'Endpoint name'}, + {'name': 'idempotency_key', 'data_type': 'string', 'nullable': False, 'comment': 'Idempotency key'}, + {'name': 'req_hash', 'data_type': 'string', 'nullable': False, 'comment': 'Request hash'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'comment': 'Processing status'}, + {'name': 'response_json', 'data_type': 'text', 'nullable': True, 'comment': 'Response json'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +CALLBACK_RECORDS_TABLE_INFO = { + 'name': 'lazyllm_callback_records', + 'comment': 'Processed callback records', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'callback_id', 'data_type': 'string', 'nullable': False, 'comment': 'Callback ID'}, + {'name': 'task_id', 'data_type': 'string', 'nullable': False, 'comment': 'Task ID'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + ], +} + + +DOCUMENTS_TABLE_INFO = { + 'name': 'lazyllm_documents', + 'comment': 'Document metadata table', + 'columns': [ + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'is_primary_key': True, + 'comment': 'Document ID'}, + {'name': 'filename', 'data_type': 'string', 'nullable': False, 'comment': 'Filename'}, + {'name': 'path', 'data_type': 'string', 'nullable': False, 'comment': 'Absolute file path'}, + {'name': 'meta', 'data_type': 'text', 'nullable': True, 'comment': 'Document metadata in JSON string'}, + {'name': 'upload_status', 'data_type': 'string', 'nullable': False, 'comment': 'Document upload status'}, + {'name': 'source_type', 'data_type': 'string', 'nullable': False, 'comment': 'Source type'}, + {'name': 'file_type', 'data_type': 'string', 'nullable': True, 'comment': 'File type suffix'}, + {'name': 'content_hash', 'data_type': 'string', 'nullable': True, 'comment': 'Content hash'}, + {'name': 'size_bytes', 'data_type': 'integer', 'nullable': True, 'comment': 'File size in bytes'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 
'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +KBS_TABLE_INFO = { + 'name': 'lazyllm_knowledge_bases', + 'comment': 'Knowledge base table', + 'columns': [ + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'is_primary_key': True, + 'comment': 'Knowledge base ID'}, + {'name': 'display_name', 'data_type': 'string', 'nullable': True, 'comment': 'Display name'}, + {'name': 'description', 'data_type': 'string', 'nullable': True, 'comment': 'Description'}, + {'name': 'doc_count', 'data_type': 'integer', 'nullable': False, 'default': 0, 'comment': 'Document count'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'default': KBStatus.ACTIVE.value, + 'comment': 'KB status'}, + {'name': 'owner_id', 'data_type': 'string', 'nullable': True, 'comment': 'Owner ID'}, + {'name': 'meta', 'data_type': 'text', 'nullable': True, 'comment': 'KB metadata in JSON string'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +KB_DOCUMENTS_TABLE_INFO = { + 'name': 'lazyllm_kb_documents', + 'comment': 'KB and document binding table', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'comment': 'Document ID'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +KB_ALGORITHM_TABLE_INFO = { + 'name': 'lazyllm_kb_algorithm', + 'comment': 'KB and algorithm binding table', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'algo_id', 'data_type': 'string', 'nullable': False, 'comment': 'Algorithm ID'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +PARSE_STATE_TABLE_INFO = { + 'name': 'lazyllm_doc_parse_state', + 'comment': 'Latest parse state table', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'comment': 'Document ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'algo_id', 'data_type': 'string', 'nullable': False, 'comment': 'Algorithm ID'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'comment': 'Current parse status'}, + {'name': 'current_task_id', 'data_type': 'string', 'nullable': True, 'comment': 'Current task ID'}, + {'name': 'task_type', 'data_type': 'string', 'nullable': True, 'comment': 'Current task type'}, + {'name': 'idempotency_key', 'data_type': 'string', 'nullable': True, 'comment': 'Idempotency key'}, + {'name': 'priority', 'data_type': 'integer', 'nullable': False, 
'default': 0, 'comment': 'Task priority'}, + {'name': 'task_score', 'data_type': 'integer', 'nullable': True, 'comment': 'Task score'}, + {'name': 'retry_count', 'data_type': 'integer', 'nullable': False, 'default': 0, 'comment': 'Retry count'}, + {'name': 'max_retry', 'data_type': 'integer', 'nullable': False, 'default': 3, 'comment': 'Max retry'}, + {'name': 'lease_owner', 'data_type': 'string', 'nullable': True, 'comment': 'Lease owner'}, + {'name': 'lease_until', 'data_type': 'datetime', 'nullable': True, 'comment': 'Lease deadline'}, + {'name': 'last_error_code', 'data_type': 'string', 'nullable': True, 'comment': 'Last error code'}, + {'name': 'last_error_msg', 'data_type': 'text', 'nullable': True, 'comment': 'Last error message'}, + {'name': 'failed_stage', 'data_type': 'string', 'nullable': True, 'comment': 'Failure stage'}, + {'name': 'queued_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Queued time'}, + {'name': 'started_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Started time'}, + {'name': 'finished_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Finished time'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +def now_ts() -> datetime: + return datetime.now() diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py new file mode 100644 index 000000000..b597b9236 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -0,0 +1,1042 @@ +from __future__ import annotations + +import hashlib +import json +import os +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +import requests +import sqlalchemy +from sqlalchemy.exc import IntegrityError + +from lazyllm.thirdparty import fastapi + +from ..utils import BaseResponse, _get_default_db_config, _orm_to_dict +from ...sql import SqlManager +from .base import ( + AddRequest, + CallbackEventType, + CALLBACK_RECORDS_TABLE_INFO, + DeleteRequest, + DocServiceError, + DOCUMENTS_TABLE_INFO, + IDEMPOTENCY_RECORDS_TABLE_INFO, + KB_ALGORITHM_TABLE_INFO, + KB_DOCUMENTS_TABLE_INFO, + KBStatus, + KBS_TABLE_INFO, + MetadataPatchRequest, + PARSE_STATE_TABLE_INFO, + ReparseRequest, + SourceType, + TaskCallbackRequest, + TaskCreateRequest, + TaskType, + TransferRequest, + UploadRequest, + DocStatus, + now_ts, +) +from .parsing_server import ParsingTaskServer + + +def _to_json(data: Optional[Dict[str, Any]]) -> str: + return json.dumps(data or {}, ensure_ascii=False) + + +def _from_json(raw: Optional[str]) -> Dict[str, Any]: + if not raw: + return {} + try: + return json.loads(raw) + except Exception: + return {} + + +def gen_doc_id(file_path: str, doc_id: Optional[str] = None) -> str: + if doc_id: + return doc_id + return hashlib.sha256(file_path.encode()).hexdigest() + + +def _stable_json(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, sort_keys=True, default=str) + + +def _hash_payload(data: Any) -> str: + return hashlib.sha256(_stable_json(data).encode()).hexdigest() + + +def _sha256_file(file_path: str) -> str: + digest = hashlib.sha256() + with open(file_path, 'rb') as fh: + for chunk in iter(lambda: fh.read(1024 * 1024), b''): + digest.update(chunk) + return digest.hexdigest() + + +class _ParserClient: + def __init__(self, parser_server: 
Optional[ParsingTaskServer] = None, parser_url: Optional[str] = None): + self._parser_server = parser_server + if parser_url: + parser_url = parser_url.rstrip('/') + if parser_url.endswith('/_call') or parser_url.endswith('/generate'): + parser_url = parser_url.rsplit('/', 1)[0] + self._parser_url = parser_url + else: + self._parser_url = None + + def _post(self, path: str, payload: Dict[str, Any]): + url = f'{self._parser_url}{path}' + resp = requests.post(url, json=payload, timeout=8) + if resp.status_code >= 400: + raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') + return resp.json() + + def _get(self, path: str, params: Optional[Dict[str, Any]] = None): + url = f'{self._parser_url}{path}' + resp = requests.get(url, params=params, timeout=8) + if resp.status_code >= 400: + raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') + return resp.json() + + def create_task(self, req: TaskCreateRequest): + if self._parser_server: + return self._parser_server.create_task(req) + data = self._post('/v1/internal/tasks/create', req.model_dump(mode='json')) + return BaseResponse.model_validate(data) + + def cancel_task(self, task_id: str): + if self._parser_server: + return self._parser_server.cancel_task(task_id) + data = self._post('/v1/internal/tasks/cancel', {'task_id': task_id}) + return BaseResponse.model_validate(data) + + def list_tasks(self, status: Optional[List[str]], page: int, page_size: int): + if self._parser_server: + return self._parser_server.list_tasks(status=status, page=page, page_size=page_size) + params: Dict[str, Any] = {'page': page, 'page_size': page_size} + if status: + params['status'] = status + data = self._get('/v1/tasks', params=params) + return BaseResponse.model_validate(data) + + def get_task(self, task_id: str): + if self._parser_server: + try: + return self._parser_server.get_task(task_id) + except (fastapi.HTTPException, requests.RequestException): + return BaseResponse(code=404, msg='task not found', data=None) + try: + data = self._get(f'/v1/tasks/{task_id}') + return BaseResponse.model_validate(data) + except RuntimeError as exc: + if '404' in str(exc): + return BaseResponse(code=404, msg='task not found', data=None) + raise + + def list_algorithms(self): + if self._parser_server: + return self._parser_server.list_algorithms() + data = self._get('/v1/algo/list') + return BaseResponse.model_validate(data) + + def get_algorithm_groups(self, algo_id: str): + if self._parser_server: + return self._parser_server.get_algorithm_groups(algo_id) + try: + data = self._get(f'/v1/algo/{algo_id}/groups') + return BaseResponse.model_validate(data) + except RuntimeError as exc: + if '404' in str(exc): + return BaseResponse(code=404, msg='algo not found', data=None) + raise + + +class DocManager: + def __init__( + self, + db_config: Optional[Dict[str, Any]] = None, + parser_server: Optional[ParsingTaskServer] = None, + parser_url: Optional[str] = None, + callback_url: Optional[str] = None, + ): + if parser_server is None and not parser_url: + raise ValueError('Either parser_server or parser_url must be provided') + + self._db_config = db_config or _get_default_db_config('doc_service') + self._db_manager = SqlManager( + **self._db_config, + tables_info_dict={'tables': [DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, + KB_ALGORITHM_TABLE_INFO, PARSE_STATE_TABLE_INFO, + IDEMPOTENCY_RECORDS_TABLE_INFO, CALLBACK_RECORDS_TABLE_INFO]}, + ) + self._ensure_indexes() + self._parser_client = 
_ParserClient(parser_server=parser_server, parser_url=parser_url) + self._callback_url = callback_url + self._upsert_default_kb() + + def set_callback_url(self, callback_url: str): + self._callback_url = callback_url + + def _ensure_indexes(self): + stmts = [ + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_docs_path ON lazyllm_documents(path)', + 'CREATE INDEX IF NOT EXISTS idx_documents_upload_status ON lazyllm_documents(upload_status)', + 'CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON lazyllm_documents(updated_at)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_display_name ' + 'ON lazyllm_knowledge_bases(display_name) WHERE display_name IS NOT NULL', + 'CREATE INDEX IF NOT EXISTS idx_kb_created_at ON lazyllm_knowledge_bases(created_at)', + 'CREATE INDEX IF NOT EXISTS idx_kb_updated_at ON lazyllm_knowledge_bases(updated_at)', + 'CREATE INDEX IF NOT EXISTS idx_kb_doc_count ON lazyllm_knowledge_bases(doc_count)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_documents ON lazyllm_kb_documents(kb_id, doc_id)', + 'CREATE INDEX IF NOT EXISTS idx_kb_documents_doc_id ON lazyllm_kb_documents(doc_id)', + 'CREATE INDEX IF NOT EXISTS idx_kb_documents_kb_id ON lazyllm_kb_documents(kb_id)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_algorithm_kb_id ON lazyllm_kb_algorithm(kb_id)', + 'CREATE INDEX IF NOT EXISTS idx_kb_algorithm_algo_id ON lazyllm_kb_algorithm(algo_id)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_parse_state_key ' + 'ON lazyllm_doc_parse_state(doc_id, kb_id, algo_id)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_parse_state_current_task ' + 'ON lazyllm_doc_parse_state(current_task_id) WHERE current_task_id IS NOT NULL', + 'CREATE INDEX IF NOT EXISTS idx_parse_sched ' + 'ON lazyllm_doc_parse_state(status, task_score, updated_at)', + 'CREATE INDEX IF NOT EXISTS idx_parse_lease ' + 'ON lazyllm_doc_parse_state(status, lease_until)', + 'CREATE INDEX IF NOT EXISTS idx_parse_kb_algo_status ' + 'ON lazyllm_doc_parse_state(kb_id, algo_id, status)', + 'CREATE INDEX IF NOT EXISTS idx_parse_task_type_status ' + 'ON lazyllm_doc_parse_state(task_type, status, updated_at)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_idempotency_endpoint_key ' + 'ON lazyllm_idempotency_records(endpoint, idempotency_key)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_callback_id ' + 'ON lazyllm_callback_records(callback_id)', + ] + for stmt in stmts: + self._db_manager.execute_commit(stmt) + + def _upsert_default_kb(self): + self._ensure_kb('__default__', display_name='__default__') + self._ensure_kb_algorithm('__default__', '__default__') + self._cleanup_idempotency_records() + + def _ensure_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None): + now = now_ts() + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if row is None: + row = Kb( + kb_id=kb_id, + display_name=display_name, + description=description, + doc_count=0, + status=KBStatus.ACTIVE.value, + owner_id=owner_id, + meta=_to_json(meta), + created_at=now, + updated_at=now, + ) + else: + if display_name is not None: + row.display_name = display_name + if description is not None: + row.description = description + if owner_id is not None: + row.owner_id = owner_id + if meta is not None: + row.meta = _to_json(meta) + if row.status == KBStatus.DELETED.value: + row.status = KBStatus.ACTIVE.value + row.updated_at = now + session.add(row) + + def 
_ensure_kb_algorithm(self, kb_id: str, algo_id: str): + now = now_ts() + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id).first() + if row is None: + row = Rel(kb_id=kb_id, algo_id=algo_id, created_at=now, updated_at=now) + elif row.algo_id != algo_id: + raise DocServiceError( + 'E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm {row.algo_id}', + {'kb_id': kb_id, 'bound_algo_id': row.algo_id, 'requested_algo_id': algo_id} + ) + else: + row.updated_at = now + session.add(row) + + def _get_kb(self, kb_id: str): + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + return _orm_to_dict(row) if row else None + + def _get_kb_algorithm(self, kb_id: str): + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id).first() + return _orm_to_dict(row) if row else None + + def _validate_kb_algorithm(self, kb_id: str, algo_id: str): + kb = self._get_kb(kb_id) + if kb is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + if kb.get('status') != KBStatus.ACTIVE.value: + raise DocServiceError('E_STATE_CONFLICT', f'kb is not active: {kb_id}', {'kb_id': kb_id}) + binding = self._get_kb_algorithm(kb_id) + if binding is None: + raise DocServiceError('E_STATE_CONFLICT', f'kb has no algorithm binding: {kb_id}', {'kb_id': kb_id}) + if binding['algo_id'] != algo_id: + raise DocServiceError( + 'E_INVALID_PARAM', f'kb {kb_id} is bound to algorithm {binding["algo_id"]}', + {'kb_id': kb_id, 'bound_algo_id': binding['algo_id'], 'requested_algo_id': algo_id} + ) + return binding + + def _ensure_kb_document(self, kb_id: str, doc_id: str): + now = now_ts() + created = False + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() + if row is None: + created = True + row = Rel(kb_id=kb_id, doc_id=doc_id, created_at=now, updated_at=now) + else: + row.updated_at = now + session.add(row) + if created: + self._refresh_kb_doc_count(kb_id) + return created + + def _remove_kb_document(self, kb_id: str, doc_id: str): + removed = False + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() + if row is not None: + session.delete(row) + removed = True + if removed: + self._refresh_kb_doc_count(kb_id) + return removed + + def _load_idempotency_record(self, endpoint: str, idempotency_key: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + row = session.query(Record).filter( + Record.endpoint == endpoint, + Record.idempotency_key == idempotency_key, + ).first() + return _orm_to_dict(row) if row else None + + def _cleanup_idempotency_records(self, ttl_days: int = 7): + cutoff = now_ts() - timedelta(days=ttl_days) + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + session.query(Record).filter(Record.updated_at < cutoff).delete() 
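[editor's sketch] The helpers around this point (`_claim_idempotency_key`, `_complete_idempotency_record`, `run_idempotent`) implement a claim-once / replay-on-retry contract backed by the `lazyllm_idempotency_records` table. A simplified in-memory analogue of that contract, for illustration only (not the SQL-backed implementation, and plain `RuntimeError` stands in for `DocServiceError`):

import hashlib, json

_records = {}  # (endpoint, idempotency_key) -> {'req_hash': ..., 'response': ...}

def run_idempotent(endpoint, idempotency_key, payload, handler):
    if not idempotency_key:
        return handler()
    req_hash = hashlib.sha256(json.dumps(payload, sort_keys=True, default=str).encode()).hexdigest()
    record = _records.get((endpoint, idempotency_key))
    if record is not None:
        if record['req_hash'] != req_hash:
            raise RuntimeError('E_IDEMPOTENCY_CONFLICT')  # same key, different request body
        return record['response']                         # replay the stored response
    response = handler()                                   # first call actually runs the handler
    _records[(endpoint, idempotency_key)] = {'req_hash': req_hash, 'response': response}
    return response

assert run_idempotent('/v1/docs/upload', 'k1', {'doc': 'a'}, lambda: {'task': 1}) == {'task': 1}
assert run_idempotent('/v1/docs/upload', 'k1', {'doc': 'a'}, lambda: {'task': 9}) == {'task': 1}  # replayed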
+ + def _claim_idempotency_key(self, endpoint: str, idempotency_key: str, req_hash: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + now = now_ts() + row = Record( + endpoint=endpoint, + idempotency_key=idempotency_key, + req_hash=req_hash, + status='PROCESSING', + response_json=None, + created_at=now, + updated_at=now, + ) + session.add(row) + session.flush() + return True + + def _complete_idempotency_record(self, endpoint: str, idempotency_key: str, response: Any): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + row = session.query(Record).filter( + Record.endpoint == endpoint, + Record.idempotency_key == idempotency_key, + ).first() + if row is None: + return + row.status = 'COMPLETED' + row.response_json = _stable_json(response) + row.updated_at = now_ts() + session.add(row) + + def _drop_idempotency_claim(self, endpoint: str, idempotency_key: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + row = session.query(Record).filter( + Record.endpoint == endpoint, + Record.idempotency_key == idempotency_key, + ).first() + if row is not None and row.status == 'PROCESSING': + session.delete(row) + + def run_idempotent(self, endpoint: str, idempotency_key: Optional[str], payload: Any, handler): + if not idempotency_key: + return handler() + req_hash = _hash_payload(payload) + try: + self._claim_idempotency_key(endpoint, idempotency_key, req_hash) + except IntegrityError: + record = self._load_idempotency_record(endpoint, idempotency_key) + if record is None: + raise DocServiceError('E_IDEMPOTENCY_IN_PROGRESS', 'idempotency request is being processed') + if record['req_hash'] != req_hash: + raise DocServiceError('E_IDEMPOTENCY_CONFLICT', 'idempotency key conflicts with different request') + if record.get('status') == 'COMPLETED' and record.get('response_json'): + return json.loads(record['response_json']) + raise DocServiceError( + 'E_IDEMPOTENCY_IN_PROGRESS', 'idempotency request is being processed', + {'endpoint': endpoint, 'idempotency_key': idempotency_key} + ) + try: + response = handler() + except Exception: + self._drop_idempotency_claim(endpoint, idempotency_key) + raise + self._complete_idempotency_record(endpoint, idempotency_key, response) + return response + + def _record_callback(self, callback_id: str, task_id: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(CALLBACK_RECORDS_TABLE_INFO['name']) + session.add(Record(callback_id=callback_id, task_id=task_id, created_at=now_ts())) + try: + session.flush() + return True + except IntegrityError: + session.rollback() + return False + + def _refresh_kb_doc_count(self, kb_id: str): + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if kb_row is None: + return + kb_row.doc_count = session.query(Rel).filter(Rel.kb_id == kb_id).count() + if kb_row.status == KBStatus.DELETING.value and kb_row.doc_count == 0: + kb_row.status = KBStatus.DELETED.value + kb_row.updated_at = now_ts() + session.add(kb_row) + + def _has_kb_document(self, kb_id: str, doc_id: str): + with self._db_manager.get_session() as 
session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + return session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() is not None + + def _doc_relation_count(self, doc_id: str): + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + return session.query(Rel).filter(Rel.doc_id == doc_id).count() + + def _get_doc(self, doc_id: str): + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + return _orm_to_dict(row) if row else None + + def _upsert_doc( + self, + doc_id: str, + filename: str, + path: str, + metadata: Dict[str, Any], + source_type: SourceType, + ): + now = now_ts() + file_type = os.path.splitext(path)[1].lstrip('.').lower() or None + size_bytes = os.path.getsize(path) if os.path.exists(path) else None + content_hash = _sha256_file(path) if os.path.exists(path) else None + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if row is None: + row = Doc( + doc_id=doc_id, + filename=filename, + path=path, + meta=_to_json(metadata), + upload_status=DocStatus.SUCCESS.value, + source_type=source_type.value, + file_type=file_type, + content_hash=content_hash, + size_bytes=size_bytes, + created_at=now, + updated_at=now, + ) + else: + row.filename = filename + row.path = path + row.meta = _to_json(metadata) + row.upload_status = DocStatus.SUCCESS.value + row.source_type = source_type.value + row.file_type = file_type + row.content_hash = content_hash + row.size_bytes = size_bytes + row.updated_at = now + session.add(row) + return self._get_doc(doc_id) + + def _get_parse_snapshot(self, doc_id: str, kb_id: str, algo_id: str): + with self._db_manager.get_session() as session: + State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + row = ( + session.query(State) + .filter(State.doc_id == doc_id, State.kb_id == kb_id, State.algo_id == algo_id) + .first() + ) + return _orm_to_dict(row) if row else None + + def _get_latest_parse_snapshot(self, doc_id: str, kb_id: str): + with self._db_manager.get_session() as session: + State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + row = ( + session.query(State) + .filter(State.doc_id == doc_id, State.kb_id == kb_id) + .order_by(State.updated_at.desc(), State.created_at.desc()) + .first() + ) + return _orm_to_dict(row) if row else None + + def _assert_action_allowed(self, doc_id: str, kb_id: str, algo_id: str, action: str): + snapshot = self._get_parse_snapshot(doc_id, kb_id, algo_id) + if snapshot is None: + return + status = snapshot.get('status') + if status == DocStatus.WORKING.value and action in ('upload', 'reparse', 'delete', 'transfer', 'metadata'): + raise DocServiceError('E_STATE_CONFLICT', f'cannot {action} while state is WORKING') + if status == DocStatus.DELETING.value and action in ('upload', 'reparse', 'delete', 'transfer', 'metadata'): + raise DocServiceError('E_STATE_CONFLICT', f'cannot {action} while state is DELETING') + + def _upsert_parse_snapshot( + self, + doc_id: str, + kb_id: str, + algo_id: str, + status: DocStatus, + task_type: Optional[TaskType] = None, + current_task_id: Optional[str] = None, + idempotency_key: Optional[str] = None, + priority: int = 0, + task_score: Optional[int] = 
None, + retry_count: int = 0, + max_retry: int = 3, + lease_owner: Optional[str] = None, + lease_until: Optional[datetime] = None, + error_code: Optional[str] = None, + error_msg: Optional[str] = None, + failed_stage: Optional[str] = None, + queued_at: Optional[datetime] = None, + started_at: Optional[datetime] = None, + finished_at: Optional[datetime] = None, + ): + now = now_ts() + with self._db_manager.get_session() as session: + State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + row = ( + session.query(State) + .filter(State.doc_id == doc_id, State.kb_id == kb_id, State.algo_id == algo_id) + .first() + ) + if row is None: + row = State( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=status.value, + task_type=task_type.value if task_type else None, + current_task_id=current_task_id, + idempotency_key=idempotency_key, + priority=priority, + task_score=task_score, + retry_count=retry_count, + max_retry=max_retry, + lease_owner=lease_owner, + lease_until=lease_until, + last_error_code=error_code, + last_error_msg=error_msg, + failed_stage=failed_stage, + queued_at=queued_at, + started_at=started_at, + finished_at=finished_at, + created_at=now, + updated_at=now, + ) + else: + row.status = status.value + if task_type is not None: + row.task_type = task_type.value + row.current_task_id = current_task_id + row.idempotency_key = idempotency_key + row.priority = priority + row.task_score = task_score + row.retry_count = retry_count + row.max_retry = max_retry + row.lease_owner = lease_owner + row.lease_until = lease_until + row.last_error_code = error_code + row.last_error_msg = error_msg + row.failed_stage = failed_stage + row.queued_at = queued_at + row.started_at = started_at + row.finished_at = finished_at + row.updated_at = now + session.add(row) + return self._get_parse_snapshot(doc_id, kb_id, algo_id) + + def _create_parser_task(self, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType): + task_id = str(uuid4()) + req = TaskCreateRequest( + task_id=task_id, + task_type=task_type, + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + callback_url=self._callback_url, + ) + task_resp = self._parser_client.create_task(req) + if task_resp.code != 200: + raise RuntimeError(f'failed to enqueue parser task: {task_resp.msg}') + return task_id + + def _enqueue_task( + self, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, + idempotency_key: Optional[str] = None, priority: int = 0, + ): + task_id = self._create_parser_task(doc_id, kb_id, algo_id, task_type) + parse_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING + snapshot = self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=parse_status, + task_type=task_type, + current_task_id=task_id, + idempotency_key=idempotency_key, + priority=priority, + queued_at=now_ts(), + started_at=None, + finished_at=None, + error_code=None, + error_msg=None, + failed_stage=None, + ) + return task_id, snapshot + + def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: + self._validate_kb_algorithm(request.kb_id, request.algo_id) + items: List[Dict[str, Any]] = [] + for item in request.items: + file_path = item.file_path + if not os.path.exists(file_path): + raise DocServiceError('E_INVALID_PARAM', f'file not found: {file_path}') + doc_id = gen_doc_id(file_path, doc_id=item.doc_id) + if self._has_kb_document(request.kb_id, doc_id): + self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'upload') + doc = 
self._upsert_doc( + doc_id=doc_id, + filename=os.path.basename(file_path), + path=file_path, + metadata=item.metadata, + source_type=request.source_type, + ) + self._ensure_kb_document(request.kb_id, doc_id) + task_id, snapshot = self._enqueue_task( + doc_id, request.kb_id, request.algo_id, TaskType.DOC_ADD, + idempotency_key=request.idempotency_key, + ) + items.append({ + 'doc_id': doc_id, + 'kb_id': request.kb_id, + 'algo_id': request.algo_id, + 'upload_status': doc['upload_status'], + 'parse_status': snapshot['status'], + 'task_id': task_id, + }) + return items + + def add_files(self, request: AddRequest) -> List[Dict[str, Any]]: + return self.upload(UploadRequest( + items=request.items, + kb_id=request.kb_id, + algo_id=request.algo_id, + source_type=request.source_type, + idempotency_key=request.idempotency_key, + )) + + def reparse(self, request: ReparseRequest) -> List[str]: + self._validate_kb_algorithm(request.kb_id, request.algo_id) + task_ids = [] + for doc_id in request.doc_ids: + if self._get_doc(doc_id) is None or not self._has_kb_document(request.kb_id, doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') + self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'reparse') + task_id, _ = self._enqueue_task( + doc_id, request.kb_id, request.algo_id, TaskType.DOC_REPARSE, + idempotency_key=request.idempotency_key, + ) + task_ids.append(task_id) + return task_ids + + def delete(self, request: DeleteRequest) -> List[Dict[str, Any]]: + self._validate_kb_algorithm(request.kb_id, request.algo_id) + items: List[Dict[str, Any]] = [] + for doc_id in request.doc_ids: + doc = self._get_doc(doc_id) + if doc is None or not self._has_kb_document(request.kb_id, doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') + self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'delete') + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if self._doc_relation_count(doc_id) <= 1: + row.upload_status = DocStatus.DELETING.value + row.updated_at = now_ts() + session.add(row) + + task_id, snapshot = self._enqueue_task( + doc_id, request.kb_id, request.algo_id, TaskType.DOC_DELETE, + idempotency_key=request.idempotency_key, + ) + items.append({ + 'doc_id': doc_id, + 'accepted': True, + 'task_id': task_id, + 'status': snapshot['status'], + 'error_code': None, + }) + return items + + def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: + items: List[Dict[str, Any]] = [] + for item in request.items: + if item.mode not in ('move', 'copy'): + raise DocServiceError( + 'E_INVALID_PARAM', f'invalid transfer mode: {item.mode}', {'mode': item.mode} + ) + doc = self._get_doc(item.doc_id) + if doc is None or not self._has_kb_document(item.source_kb_id, item.doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') + self._validate_kb_algorithm(item.source_kb_id, item.source_algo_id) + self._validate_kb_algorithm(item.target_kb_id, item.target_algo_id) + self._assert_action_allowed(item.doc_id, item.source_kb_id, item.source_algo_id, 'transfer') + self._ensure_kb_document(item.target_kb_id, item.doc_id) + if item.mode == 'move': + self._remove_kb_document(item.source_kb_id, item.doc_id) + task_id, snapshot = self._enqueue_task( + item.doc_id, item.target_kb_id, item.target_algo_id, TaskType.DOC_TRANSFER, + idempotency_key=request.idempotency_key, + ) + 
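[editor's sketch] For reference while reading the transfer flow in this hunk: a hedged example of the request object `DocManager.transfer` consumes, using the `TransferItem`/`TransferRequest` models from `doc_service/base.py` earlier in this patch. The HTTP route that exposes transfer is not shown in this hunk, so the sketch calls the manager directly; the doc id and kb ids are made-up values.

from lazyllm.tools.rag.doc_service.base import TransferItem, TransferRequest

req = TransferRequest(
    items=[TransferItem(doc_id='doc-123', source_kb_id='kb_src', source_algo_id='__default__',
                        target_kb_id='kb_dst', target_algo_id='__default__', mode='copy')],
    idempotency_key='transfer-doc-123-to-kb_dst',
)
# manager.transfer(req) binds the doc to kb_dst (and unbinds it from kb_src when mode == 'move'),
# then enqueues a DOC_TRANSFER task and returns one entry per doc carrying its task_id.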
items.append({ + 'doc_id': item.doc_id, + 'task_id': task_id, + 'source_kb_id': item.source_kb_id, + 'target_kb_id': item.target_kb_id, + 'source_algo_id': item.source_algo_id, + 'target_algo_id': item.target_algo_id, + 'mode': item.mode, + 'status': snapshot['status'], + }) + return items + + def patch_metadata(self, request: MetadataPatchRequest): + self._validate_kb_algorithm(request.kb_id, request.algo_id) + updated = [] + failed = [] + for item in request.items: + doc = self._get_doc(item.doc_id) + if doc is None or not self._has_kb_document(request.kb_id, item.doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') + self._assert_action_allowed(item.doc_id, request.kb_id, request.algo_id, 'metadata') + merged = _from_json(doc.get('meta')) + merged.update(item.patch) + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == item.doc_id).first() + row.meta = _to_json(merged) + row.updated_at = now_ts() + session.add(row) + + task_id, _ = self._enqueue_task( + item.doc_id, request.kb_id, request.algo_id, TaskType.DOC_UPDATE_META, + idempotency_key=request.idempotency_key, + ) + updated.append({'doc_id': item.doc_id, 'task_id': task_id}) + return { + 'updated_count': len(updated), + 'doc_ids': [u['doc_id'] for u in updated], + 'failed_items': failed, + 'items': updated, + } + + def _sync_doc_upload_status(self, doc_id: str): + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if row is None: + return + has_rel = session.query(Rel).filter(Rel.doc_id == doc_id).first() is not None + row.upload_status = DocStatus.SUCCESS.value if has_rel else DocStatus.DELETED.value + row.updated_at = now_ts() + session.add(row) + + def on_task_callback(self, callback: TaskCallbackRequest): + if not self._record_callback(callback.callback_id, callback.task_id): + return {'ack': True, 'deduped': True, 'ignored_reason': None} + task = self._parser_client.get_task(callback.task_id) + if task.code != 200: + return {'ack': True, 'ignored_reason': 'task_not_found'} + task_data = task.data + doc_id = task_data['doc_id'] + kb_id = task_data['kb_id'] + algo_id = task_data['algo_id'] + task_type = TaskType(task_data['task_type']) + snapshot = self._get_parse_snapshot(doc_id, kb_id, algo_id) + if snapshot and snapshot.get('current_task_id') and snapshot['current_task_id'] != callback.task_id: + return {'ack': True, 'deduped': False, 'ignored_reason': 'stale_task_callback'} + + if callback.event_type == CallbackEventType.START: + self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=DocStatus.WORKING, + task_type=task_type, + current_task_id=callback.task_id, + started_at=now_ts(), + queued_at=None, + finished_at=None, + ) + return {'ack': True, 'deduped': False, 'ignored_reason': None} + + final_status = callback.status + failed_stage = None + if final_status == DocStatus.FAILED: + failed_stage = 'DELETE' if task_type == TaskType.DOC_DELETE else 'PARSE' + + self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=final_status, + task_type=task_type, + current_task_id=callback.task_id, + error_code=callback.error_code, + error_msg=callback.error_msg, + failed_stage=failed_stage, + 
finished_at=now_ts(), + ) + + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED: + self._remove_kb_document(kb_id, doc_id) + self._sync_doc_upload_status(doc_id) + + return {'ack': True, 'deduped': False, 'ignored_reason': None} + + def list_docs( + self, + status: Optional[List[str]] = None, + kb_id: Optional[str] = None, + algo_id: Optional[str] = None, + keyword: Optional[str] = None, + include_deleted_or_canceled: bool = True, + page: int = 1, + page_size: int = 20, + ): + page = max(page, 1) + page_size = max(1, page_size) + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + query = session.query(Doc, Rel).join(Rel, Doc.doc_id == Rel.doc_id) + + if kb_id: + query = query.filter(Rel.kb_id == kb_id) + if keyword: + like_expr = f'%{keyword}%' + query = query.filter((Doc.filename.like(like_expr)) | (Doc.path.like(like_expr))) + if not include_deleted_or_canceled: + query = query.filter(~Doc.upload_status.in_([DocStatus.DELETED.value, DocStatus.CANCELED.value])) + + rows = query.order_by(Rel.updated_at.desc(), Doc.updated_at.desc()).all() + items = [] + for doc_row, rel_row in rows: + doc = _orm_to_dict(doc_row) + relation = _orm_to_dict(rel_row) + snapshot = ( + self._get_parse_snapshot(doc['doc_id'], relation['kb_id'], algo_id) + if algo_id else + self._get_latest_parse_snapshot(doc['doc_id'], relation['kb_id']) + ) + if status and (snapshot is None or snapshot.get('status') not in status): + continue + doc['metadata'] = _from_json(doc.get('meta')) + items.append({'doc': doc, 'relation': relation, 'snapshot': snapshot}) + total = len(items) + page_items = items[(page - 1) * page_size:page * page_size] + return {'items': page_items, 'total': total, 'page': page, 'page_size': page_size} + + def get_doc_detail(self, doc_id: str): + doc = self._get_doc(doc_id) + if not doc: + raise DocServiceError('E_NOT_FOUND', f'doc not found: {doc_id}', {'doc_id': doc_id}) + + doc['metadata'] = _from_json(doc.get('meta')) + snapshots = self.list_docs(page=1, page_size=2000)['items'] + matched_items = [] + for item in snapshots: + if item['doc']['doc_id'] == doc_id: + matched_items.append(item) + relation = matched_items[0].get('relation') if matched_items else None + snapshot = matched_items[0].get('snapshot') if matched_items else None + latest_task = None + if snapshot and snapshot.get('current_task_id'): + latest_task_resp = self._parser_client.get_task(snapshot['current_task_id']) + if latest_task_resp.code == 200: + latest_task = latest_task_resp.data + return { + 'doc': doc, + 'relation': relation, + 'snapshot': snapshot, + 'latest_task': latest_task, + 'relations': [item.get('relation') for item in matched_items], + 'snapshots': [item.get('snapshot') for item in matched_items if item.get('snapshot') is not None], + } + + def list_tasks(self, status: Optional[List[str]], page: int, page_size: int): + return self._parser_client.list_tasks(status=status, page=page, page_size=page_size) + + def get_task(self, task_id: str): + return self._parser_client.get_task(task_id) + + def get_tasks_batch(self, task_ids: List[str]): + items = [] + for task_id in task_ids: + resp = self._parser_client.get_task(task_id) + if resp.code == 200 and resp.data is not None: + items.append(resp.data) + return {'items': items} + + def cancel_task(self, task_id: str): + return self._parser_client.cancel_task(task_id) + + def list_algorithms(self): + 
resp = self._parser_client.list_algorithms() + if resp.code != 200: + raise fastapi.HTTPException(status_code=502, detail=resp.msg) + return resp.data + + def get_algo_groups(self, algo_id: str): + resp = self._parser_client.get_algorithm_groups(algo_id) + if resp.code == 404: + raise fastapi.HTTPException(status_code=404, detail='algo not found') + if resp.code != 200: + raise fastapi.HTTPException(status_code=502, detail=resp.msg) + return resp.data + + def list_algorithms_compat(self): + items = self.list_algorithms() + return {'items': items} + + def get_algorithm_info(self, algo_id: str): + algorithms = self.list_algorithms() + for item in algorithms: + if item['algo_id'] == algo_id: + data = dict(item) + data['groups'] = self.get_algo_groups(algo_id) + return data + raise DocServiceError('E_NOT_FOUND', f'algo not found: {algo_id}') + + def list_chunks(self, page: int = 1, page_size: int = 20): + return {'items': [], 'total': 0, 'page': page, 'page_size': page_size} + + def health(self): + return { + 'status': 'ok', + 'version': 'v1-mock', + 'deps': {'sql': True}, + } + + def list_kbs(self): + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + rows = session.query(Kb).order_by(Kb.updated_at.desc()).all() + items = [] + for row in rows: + items.append({ + 'kb_id': row.kb_id, + 'display_name': row.display_name, + 'description': row.description, + 'doc_count': row.doc_count, + 'status': row.status, + 'owner_id': row.owner_id, + 'meta': _from_json(row.meta), + 'created_at': row.created_at, + 'updated_at': row.updated_at, + }) + return {'items': items} + + def create_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + algo_id: str = '__default__'): + if not kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + binding = self._get_kb_algorithm(kb_id) + if binding is not None and binding['algo_id'] != algo_id: + raise DocServiceError( + 'E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm {binding["algo_id"]}', + {'kb_id': kb_id, 'bound_algo_id': binding['algo_id'], 'requested_algo_id': algo_id} + ) + self._ensure_kb(kb_id, display_name=display_name, description=description, owner_id=owner_id, meta=meta) + self._ensure_kb_algorithm(kb_id, algo_id) + return {'kb_id': kb_id, 'status': KBStatus.ACTIVE.value} + + def delete_kb(self, kb_id: str): + if not kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + Snap = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if kb_row is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + states = ( + session.query(Snap) + .join(Rel, sqlalchemy.and_(Snap.doc_id == Rel.doc_id, Snap.kb_id == Rel.kb_id)) + .filter(Rel.kb_id == kb_id, ~Snap.status.in_([DocStatus.DELETED.value, DocStatus.CANCELED.value])) + .all() + ) + task_ids = [] + for row in states: + task_id, _ = self._enqueue_task(row.doc_id, row.kb_id, row.algo_id, TaskType.DOC_DELETE) + task_ids.append(task_id) + new_status = KBStatus.DELETING.value if task_ids else KBStatus.DELETED.value + with self._db_manager.get_session() as session: + Kb = 
self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + kb_row.status = new_status + kb_row.updated_at = now_ts() + session.add(kb_row) + return {'kb_id': kb_id, 'status': new_status, 'task_ids': task_ids} + + def delete_kbs(self, kb_ids: List[str]): + if not kb_ids: + raise DocServiceError('E_INVALID_PARAM', 'kb_ids is required', {'kb_ids': kb_ids}) + items = [] + for kb_id in kb_ids: + items.append(self.delete_kb(kb_id)) + return {'items': items} diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py new file mode 100644 index 000000000..696fe3d0e --- /dev/null +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -0,0 +1,577 @@ +from __future__ import annotations + +import hashlib +import json +import os +import traceback +from typing import Any, Dict, List, Optional + +from lazyllm import LOG, FastapiApp as app, ModuleBase, ServerModule, UrlModule, once_wrapper +from lazyllm.thirdparty import fastapi + +from ..utils import BaseResponse, _get_default_db_config, ensure_call_endpoint +from .base import AddRequest, DeleteRequest, DocServiceError, MetadataPatchRequest, ReparseRequest +from .base import SourceType, TaskCallbackRequest +from .base import TransferRequest +from .base import UploadRequest, AddFileItem +from .doc_manager import DocManager +from .parsing_server import ParsingTaskServer + + +class DocServer(ModuleBase): + class _Impl: + def __init__( + self, + storage_dir: str, + db_config: Optional[Dict[str, Any]] = None, + parser_db_config: Optional[Dict[str, Any]] = None, + parser_poll_interval: float = 0.05, + parser_url: Optional[str] = None, + callback_url: Optional[str] = None, + ): + self._storage_dir = storage_dir + self._db_config = db_config + self._parser_db_config = parser_db_config + self._parser_poll_interval = parser_poll_interval + self._parser_url = parser_url + self._callback_url = callback_url + self._parser = None + self._manager = None + + @once_wrapper(reset_on_pickle=True) + def _lazy_init(self): + os.makedirs(self._storage_dir, exist_ok=True) + if self._parser_url: + self._manager = DocManager( + db_config=self._db_config, + parser_url=self._parser_url, + callback_url=self._callback_url, + ) + else: + parser_db = self._parser_db_config or _get_default_db_config('doc_service_parser') + self._parser = ParsingTaskServer(db_config=parser_db, poll_interval=self._parser_poll_interval) + self._parser.start() + self._manager = DocManager( + db_config=self._db_config, + parser_server=self._parser, + callback_url=self._callback_url, + ) + + def stop(self): + self._lazy_init() + if self._parser: + self._parser.stop() + + def set_runtime_callback_url(self, callback_url: str): + self._lazy_init() + self._manager.set_callback_url(callback_url) + + @staticmethod + def _response(data=None, code=200, msg='success', status_code=200): + payload = BaseResponse(code=code, msg=msg, data=data).model_dump(mode='json') + return fastapi.responses.JSONResponse(status_code=status_code, content=payload) + + def _run(self, func, *args, success_msg='success', **kwargs): + try: + data = func(*args, **kwargs) + return self._response(data=data, msg=success_msg) + except DocServiceError as exc: + data = dict(exc.data or {}) + data.setdefault('biz_code', exc.biz_code) + return self._response(data=data, code=exc.http_status, msg=exc.msg, status_code=exc.http_status) + except fastapi.HTTPException as exc: + detail = exc.detail if isinstance(exc.detail, dict) else {} + data = 
detail.get('data') + if isinstance(data, dict) and 'biz_code' not in data and detail.get('code'): + data['biz_code'] = detail['code'] + code = exc.status_code + msg = detail.get('msg', str(exc.detail)) + return self._response(data=data, code=code, msg=msg, status_code=exc.status_code) + + @staticmethod + def _build_upload_payload(request: UploadRequest, file_identities: Optional[List[Dict[str, Any]]] = None): + items = file_identities + if items is None: + items = [] + for idx, item in enumerate(request.items): + content_hash = None + size_bytes = None + if os.path.exists(item.file_path): + with open(item.file_path, 'rb') as fh: + content = fh.read() + content_hash = hashlib.sha256(content).hexdigest() + size_bytes = len(content) + items.append({ + 'filename': os.path.basename(item.file_path), + 'content_hash': content_hash, + 'size_bytes': size_bytes, + 'doc_id': item.doc_id if idx == 0 else None, + }) + return { + 'kb_id': request.kb_id, + 'algo_id': request.algo_id, + 'source_type': request.source_type.value, + 'idempotency_key': request.idempotency_key, + 'items': items, + } + + def _gen_unique_upload_path(self, filename: str, reserved_paths: Optional[set] = None): + safe_name = os.path.basename(filename) or 'upload.bin' + file_path = os.path.join(self._storage_dir, safe_name) + reserved_paths = reserved_paths or set() + if file_path not in reserved_paths and not os.path.exists(file_path): + return file_path + + suffix = os.path.splitext(safe_name)[1] + prefix = safe_name[:-len(suffix)] if suffix else safe_name + for idx in range(1, 10000): + candidate = os.path.join(self._storage_dir, f'{prefix}-{idx}{suffix}') + if candidate not in reserved_paths and not os.path.exists(candidate): + return candidate + return os.path.join(self._storage_dir, f'{prefix}-{hashlib.sha256(safe_name.encode()).hexdigest()[:8]}{suffix}') + + def _run_upload(self, request: UploadRequest, payload: Optional[Dict[str, Any]] = None): + idem_payload = payload or self._build_upload_payload(request) + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/upload', request.idempotency_key, idem_payload, + lambda: {'items': self._manager.upload(request)} + )) + + def upload_request(self, request: UploadRequest): + self._lazy_init() + return self._run_upload(request) + + @app.post('/v1/docs/upload') + async def upload( + self, + request: 'fastapi.Request', + kb_id: str = '__default__', + algo_id: str = '__default__', + source_type: SourceType = SourceType.API, + doc_id: Optional[str] = None, + idempotency_key: Optional[str] = None, + ): + self._lazy_init() + form = await request.form() + files = form.getlist('files') + if not files: + raise fastapi.HTTPException(status_code=400, detail='files is required') + buffered_files = [] + file_identities = [] + for idx, file in enumerate(files): + filename = getattr(file, 'filename', None) or str(getattr(file, 'name', 'upload.bin')) + content = await file.read() if hasattr(file, 'read') else file.file.read() + buffered_files.append({'filename': filename, 'content': content}) + file_identities.append({ + 'filename': filename, + 'content_hash': hashlib.sha256(content).hexdigest(), + 'size_bytes': len(content), + 'doc_id': doc_id if idx == 0 else None, + }) + + def _handle_upload(): + saved_paths = [] + reserved_paths = set() + for item in buffered_files: + file_path = self._gen_unique_upload_path(item['filename'], reserved_paths) + with open(file_path, 'wb') as fh: + fh.write(item['content']) + saved_paths.append(file_path) + reserved_paths.add(file_path) + 
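            # Files are persisted in request order so that a caller-supplied doc_id
            # binds to the first saved path below; reserved_paths guards against two
            # identically named files in one request resolving to the same target path.
            # The whole handler runs under run_idempotent, keyed by idempotency_key and
            # the hashed file identities collected above, so a retried upload with the
            # same key is deduplicated instead of re-saving the files.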
upload_request = UploadRequest( + items=[AddFileItem(file_path=path, doc_id=(doc_id if idx == 0 else None)) + for idx, path in enumerate(saved_paths)], + kb_id=kb_id, + algo_id=algo_id, + source_type=source_type, + idempotency_key=idempotency_key, + ) + return {'items': self._manager.upload(upload_request)} + + payload = { + 'kb_id': kb_id, + 'algo_id': algo_id, + 'source_type': source_type.value, + 'idempotency_key': idempotency_key, + 'items': file_identities, + } + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/upload', idempotency_key, payload, _handle_upload + )) + + @app.post('/v1/docs/add') + def add(self, request: AddRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/add', request.idempotency_key, payload, lambda: {'items': self._manager.add_files(request)} + )) + + @app.post('/v1/docs/reparse') + def reparse(self, request: ReparseRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/reparse', request.idempotency_key, payload, + lambda: {'task_ids': self._manager.reparse(request)} + )) + + @app.post('/v1/docs/delete') + def delete(self, request: DeleteRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/delete', request.idempotency_key, payload, lambda: {'items': self._manager.delete(request)} + )) + + @app.post('/v1/docs/transfer') + def transfer(self, request: TransferRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/transfer', request.idempotency_key, payload, + lambda: {'items': self._manager.transfer(request)} + )) + + @app.get('/v1/docs') + def list_docs( + self, + status: Optional[List[str]] = None, + kb_id: Optional[str] = None, + algo_id: Optional[str] = None, + keyword: Optional[str] = None, + include_deleted_or_canceled: bool = True, + page: int = 1, + page_size: int = 20, + ): + self._lazy_init() + data = self._manager.list_docs( + status=status, + kb_id=kb_id, + algo_id=algo_id, + keyword=keyword, + include_deleted_or_canceled=include_deleted_or_canceled, + page=page, + page_size=page_size, + ) + return BaseResponse(code=200, msg='success', data=data) + + @app.get('/v1/docs/{doc_id}') + def get_doc(self, doc_id: str): + self._lazy_init() + return self._run(lambda: self._manager.get_doc_detail(doc_id)) + + @app.post('/v1/docs/metadata/patch') + def patch_metadata(self, request: MetadataPatchRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/metadata/patch', request.idempotency_key, payload, + lambda: self._manager.patch_metadata(request) + )) + + @app.get('/v1/tasks') + def list_tasks(self, status: Optional[List[str]] = None, page: int = 1, page_size: int = 20): + self._lazy_init() + resp = self._manager.list_tasks(status, page, page_size) + return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + + @app.get('/v1/tasks/{task_id}') + def get_task(self, task_id: str): + self._lazy_init() + resp = self._manager.get_task(task_id) + return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + + def cancel_task_by_id(self, task_id: str): + self._lazy_init() + resp = self._manager.cancel_task(task_id) + return self._response(data=resp.data, 
code=resp.code, msg=resp.msg, status_code=resp.code) + + @app.post('/v1/tasks/cancel') + async def cancel_task(self, request: 'fastapi.Request'): + payload = await request.json() + task_id = payload.get('task_id') + if not task_id: + raise fastapi.HTTPException(status_code=400, detail='task_id is required') + idempotency_key = payload.get('idempotency_key') + + def _cancel(): + resp = self._manager.cancel_task(task_id) + if resp.code == 404: + raise DocServiceError('E_NOT_FOUND', resp.msg, resp.data) + if resp.code == 409: + raise DocServiceError('E_STATE_CONFLICT', resp.msg, resp.data) + if resp.code != 200: + raise DocServiceError('E_INVALID_PARAM', resp.msg, resp.data) + return resp.data + return self._run(lambda: self._manager.run_idempotent( + '/v1/tasks/cancel', idempotency_key, payload, _cancel + )) + + @app.post('/v1/internal/callbacks/tasks') + def task_callback(self, callback: TaskCallbackRequest): + self._lazy_init() + return self._run(lambda: self._manager.on_task_callback(callback)) + + @app.get('/v1/algo/list') + def list_algo(self): + self._lazy_init() + return self._run(lambda: self._manager.list_algorithms()) + + @app.get('/v1/algo/{algo_id}/groups') + def get_algo_groups(self, algo_id: str): + self._lazy_init() + return self._run(lambda: self._manager.get_algo_groups(algo_id)) + + @app.get('/v1/algorithms') + def list_algorithms(self): + self._lazy_init() + return self._run(lambda: self._manager.list_algorithms_compat()) + + def list_algorithms_impl(self): + self._lazy_init() + return self._run(lambda: self._manager.list_algorithms_compat()) + + @app.post('/v1/algorithms/info') + async def get_algorithm_info(self, request: 'fastapi.Request'): + self._lazy_init() + payload = await request.json() + algo_id = payload.get('algo_id') + if not algo_id: + return self._response(data={'biz_code': 'E_INVALID_PARAM'}, code=400, + msg='algo_id is required', status_code=400) + return self._run(lambda: self._manager.get_algorithm_info(algo_id)) + + def get_algorithm_info_impl(self, algo_id: str): + self._lazy_init() + return self._run(lambda: self._manager.get_algorithm_info(algo_id)) + + @app.get('/v1/chunks') + def list_chunks(self, page: int = 1, page_size: int = 20): + self._lazy_init() + return self._run(lambda: self._manager.list_chunks(page=page, page_size=page_size)) + + @app.post('/v1/tasks/batch') + async def get_tasks_batch(self, request: 'fastapi.Request'): + self._lazy_init() + payload = await request.json() + task_ids = payload.get('task_ids') or [] + return self._run(lambda: self._manager.get_tasks_batch(task_ids)) + + def get_tasks_batch_impl(self, task_ids: List[str]): + self._lazy_init() + return self._run(lambda: self._manager.get_tasks_batch(task_ids)) + + @app.post('/v1/tasks/info') + async def get_task_info(self, request: 'fastapi.Request'): + self._lazy_init() + payload = await request.json() + task_id = payload.get('task_id') + if not task_id: + return self._response(data={'biz_code': 'E_INVALID_PARAM'}, code=400, + msg='task_id is required', status_code=400) + resp = self._manager.get_task(task_id) + return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + + def get_task_info_impl(self, task_id: str): + self._lazy_init() + resp = self._manager.get_task(task_id) + return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + + @app.get('/v1/kbs') + def list_kbs(self): + self._lazy_init() + return self._run(lambda: self._manager.list_kbs()) + + def create_kb_by_id(self, kb_id: str, display_name: 
Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + algo_id: str = '__default__'): + self._lazy_init() + return self._run(lambda: self._manager.create_kb( + kb_id, + display_name=display_name, + description=description, + owner_id=owner_id, + meta=meta, + algo_id=algo_id, + )) + + @app.post('/v1/kbs') + async def create_kb(self, request: 'fastapi.Request'): + payload = await request.json() + idempotency_key = payload.get('idempotency_key') + return self._run(lambda: self._manager.run_idempotent( + '/v1/kbs', idempotency_key, payload, + lambda: self._manager.create_kb( + payload.get('kb_id'), + display_name=payload.get('display_name'), + description=payload.get('description'), + owner_id=payload.get('owner_id'), + meta=payload.get('meta'), + algo_id=payload.get('algo_id', '__default__'), + ) + )) + + @app.delete('/v1/kbs/{kb_id}') + def delete_kb(self, kb_id: str, idempotency_key: Optional[str] = None): + self._lazy_init() + payload = {'kb_id': kb_id} + return self._run(lambda: self._manager.run_idempotent( + '/v1/kbs/{kb_id}:delete', idempotency_key, payload, lambda: self._manager.delete_kb(kb_id) + )) + + @app.delete('/v1/kbs') + async def delete_kbs(self, request: 'fastapi.Request'): + self._lazy_init() + payload = await request.json() + kb_ids = payload.get('kb_ids') or [] + idempotency_key = payload.get('idempotency_key') + return self._run(lambda: self._manager.run_idempotent( + '/v1/kbs:delete', idempotency_key, payload, lambda: self._manager.delete_kbs(kb_ids) + )) + + def delete_kbs_impl(self, kb_ids: List[str], idempotency_key: Optional[str] = None): + self._lazy_init() + payload = {'kb_ids': kb_ids} + return self._run(lambda: self._manager.run_idempotent( + '/v1/kbs:delete', idempotency_key, payload, lambda: self._manager.delete_kbs(kb_ids) + )) + + @app.get('/v1/health') + def health(self): + self._lazy_init() + return BaseResponse(code=200, msg='success', data=self._manager.health()) + + def __call__(self, func_name: str, *args, **kwargs): + return getattr(self, func_name)(*args, **kwargs) + + def __init__( + self, + port: Optional[int] = None, + url: Optional[str] = None, + parser_url: Optional[str] = None, + db_config: Optional[Dict[str, Any]] = None, + parser_db_config: Optional[Dict[str, Any]] = None, + parser_poll_interval: float = 0.05, + storage_dir: Optional[str] = None, + callback_url: Optional[str] = None, + launcher=None, + ): + super().__init__() + self._raw_impl = None + self._storage_dir = storage_dir or os.path.join(os.getcwd(), '.doc_service_uploads') + self._db_config = db_config or _get_default_db_config('doc_service') + self._parser_db_config = parser_db_config or _get_default_db_config('doc_service_parser') + if url: + self._impl = UrlModule(url=ensure_call_endpoint(url)) + else: + self._raw_impl = DocServer._Impl( + storage_dir=self._storage_dir, + db_config=self._db_config, + parser_db_config=self._parser_db_config, + parser_poll_interval=parser_poll_interval, + parser_url=parser_url, + callback_url=callback_url, + ) + self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) + + def start(self): + result = super().start() + if self._raw_impl and isinstance(self._impl, ServerModule): + try: + callback_url = self._impl._url.rsplit('/', 1)[0] + '/v1/internal/callbacks/tasks' + self._dispatch('set_runtime_callback_url', callback_url) + except Exception as exc: + LOG.warning(f'[DocServer] failed to set runtime callback url: {exc}') + return result + + def stop(self): 
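        # Tear down the in-process implementation first (failures are logged rather
        # than raised), then stop the underlying ServerModule.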
+ if self._raw_impl: + try: + self._dispatch('stop') + except Exception as exc: + LOG.warning(f'[DocServer] stop impl failed: {exc}, {traceback.format_exc()}') + if isinstance(self._impl, ServerModule): + self._impl.stop() + + @property + def url(self): + return self._impl._url + + @property + def _url(self): + return self.url + + @staticmethod + def _normalize_dispatch_result(result): + if isinstance(result, fastapi.responses.JSONResponse): + return json.loads(result.body.decode()) + return result + + def _dispatch(self, method: str, *args, **kwargs): + if isinstance(self._impl, ServerModule): + return self._normalize_dispatch_result(self._impl._call(method, *args, **kwargs)) + return self._normalize_dispatch_result(getattr(self._impl, method)(*args, **kwargs)) + + # Method-call style wrappers + def upload(self, request: UploadRequest): + return self._dispatch('upload_request', request) + + def add(self, request: AddRequest): + return self._dispatch('add', request) + + def reparse(self, request: ReparseRequest): + return self._dispatch('reparse', request) + + def delete(self, request: DeleteRequest): + return self._dispatch('delete', request) + + def transfer(self, request: TransferRequest): + return self._dispatch('transfer', request) + + def patch_metadata(self, request: MetadataPatchRequest): + return self._dispatch('patch_metadata', request) + + def list_docs(self, **kwargs): + return self._dispatch('list_docs', **kwargs) + + def get_doc(self, doc_id: str): + return self._dispatch('get_doc', doc_id) + + def list_tasks(self, **kwargs): + return self._dispatch('list_tasks', **kwargs) + + def get_tasks_batch(self, task_ids: List[str]): + return self._dispatch('get_tasks_batch_impl', task_ids) + + def get_task_info(self, task_id: str): + return self._dispatch('get_task_info_impl', task_id) + + def get_task(self, task_id: str): + return self._dispatch('get_task', task_id) + + def cancel_task(self, task_id: str): + return self._dispatch('cancel_task_by_id', task_id) + + def list_kbs(self): + return self._dispatch('list_kbs') + + def list_chunks(self, page: int = 1, page_size: int = 20): + return self._dispatch('list_chunks', page, page_size) + + def list_algorithms(self): + return self._dispatch('list_algorithms_impl') + + def get_algorithm_info(self, algo_id: str): + return self._dispatch('get_algorithm_info_impl', algo_id) + + def create_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + algo_id: str = '__default__'): + return self._dispatch('create_kb_by_id', kb_id, display_name, description, owner_id, meta, algo_id) + + def delete_kb(self, kb_id: str): + return self._dispatch('delete_kb', kb_id) + + def delete_kbs(self, kb_ids: List[str]): + return self._dispatch('delete_kbs_impl', kb_ids) diff --git a/lazyllm/tools/rag/doc_service/parsing_server.py b/lazyllm/tools/rag/doc_service/parsing_server.py new file mode 100644 index 000000000..d5cca0766 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/parsing_server.py @@ -0,0 +1,434 @@ +from __future__ import annotations +''' +Mock parsing execution service used in phase-1 refactor validation. + +Note: +This is intentionally isolated from `lazyllm.tools.rag.parsing_service` so that +DocService API contract and state machine can be validated without requiring the +full parser runtime and algorithm registry. 
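A minimal in-process wiring sketch (illustrative only; it mirrors what
DocServer._Impl._lazy_init does, and the sqlite paths below are placeholders):

    from lazyllm.tools.rag.doc_service.doc_manager import DocManager
    from lazyllm.tools.rag.doc_service.parsing_server import ParsingTaskServer

    parser_db = {'db_type': 'sqlite', 'user': None, 'password': None,
                 'host': None, 'port': None, 'db_name': '/tmp/parser_mock.db'}
    doc_db = dict(parser_db, db_name='/tmp/doc_service_mock.db')

    parser = ParsingTaskServer(db_config=parser_db, poll_interval=0.05)
    parser.start()  # spawns the mock worker loop
    manager = DocManager(db_config=doc_db, parser_server=parser, callback_url=None)
    # The mock worker sleeps one poll interval per task, then emits START/FINISH
    # callbacks and reports SUCCESS (or DELETED for DOC_DELETE tasks).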
+''' + +import json +import threading +import time +import traceback +from typing import Any, Callable, Dict, List, Optional +from datetime import datetime + +import cloudpickle +import requests + +from lazyllm import LOG, FastapiApp as app, ModuleBase, ServerModule, UrlModule, once_wrapper +from lazyllm.thirdparty import fastapi + +from ..utils import BaseResponse, _get_default_db_config, _orm_to_dict, ensure_call_endpoint +from ...sql import SqlManager +from ..parsing_service.base import ALGORITHM_TABLE_INFO +from .base import ( + CallbackEventType, + TaskCallbackRequest, + TaskCreateRequest, + DocStatus, + now_ts, +) + +PARSER_TASK_TABLE_INFO = { + 'name': 'lazyllm_parse_tasks', + 'comment': 'Parse task table for mock parser service', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'task_id', 'data_type': 'string', 'nullable': False, 'comment': 'Task ID'}, + {'name': 'task_type', 'data_type': 'string', 'nullable': False, 'comment': 'Task type'}, + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'comment': 'Document ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'algo_id', 'data_type': 'string', 'nullable': False, 'comment': 'Algorithm ID'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'comment': 'Task status'}, + {'name': 'priority', 'data_type': 'integer', 'nullable': False, 'default': 0, 'comment': 'Task priority'}, + {'name': 'message', 'data_type': 'text', 'nullable': False, 'comment': 'Task payload'}, + {'name': 'callback_url', 'data_type': 'string', 'nullable': True, 'comment': 'Callback URL'}, + {'name': 'error_code', 'data_type': 'string', 'nullable': True, 'comment': 'Error code'}, + {'name': 'error_msg', 'data_type': 'text', 'nullable': True, 'comment': 'Error message'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + {'name': 'started_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Started time'}, + {'name': 'finished_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Finished time'}, + ], +} + + +class ParsingTaskServer(ModuleBase): + class _Impl: + def __init__( + self, + db_config: Optional[Dict[str, Any]] = None, + poll_interval: float = 0.05, + callback_func: Optional[Callable[[TaskCallbackRequest], None]] = None, + ): + self._db_config = db_config or _get_default_db_config('doc_service_parser') + self._poll_interval = poll_interval + self._db_manager = None + self._task_thread = None + self._shutdown = False + self._callback_func = callback_func + + @once_wrapper(reset_on_pickle=True) + def _lazy_init(self): + self._db_manager = SqlManager( + **self._db_config, + tables_info_dict={'tables': [PARSER_TASK_TABLE_INFO, ALGORITHM_TABLE_INFO]}, + ) + self._ensure_indexes() + self._upsert_default_algorithm() + self._shutdown = False + self._task_thread = threading.Thread(target=self._task_worker, daemon=True) + self._task_thread.start() + + def stop(self): + self._shutdown = True + if self._task_thread and self._task_thread.is_alive(): + self._task_thread.join(timeout=2) + + def _ensure_indexes(self): + stmts = [ + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_parse_tasks_task_id ON lazyllm_parse_tasks(task_id)', + 'CREATE INDEX IF NOT EXISTS idx_parse_tasks_status ON 
lazyllm_parse_tasks(status, updated_at)', + 'CREATE INDEX IF NOT EXISTS idx_parse_tasks_doc ON lazyllm_parse_tasks(doc_id, kb_id, algo_id)', + ] + for stmt in stmts: + self._db_manager.execute_commit(stmt) + + def _upsert_default_algorithm(self): + default_group_info = [ + {'name': 'CoarseChunk', 'type': 'chunk', 'display_name': 'Coarse Chunk'}, + {'name': 'line', 'type': 'chunk', 'display_name': 'Line Chunk'}, + ] + default_info = { + 'store': None, + 'reader': None, + 'node_groups': { + item['name']: {'group_type': item['type'], 'display_name': item['display_name']} + for item in default_group_info + }, + 'schema_extractor': None, + } + with self._db_manager.get_session() as session: + Algo = self._db_manager.get_table_orm_class(ALGORITHM_TABLE_INFO['name']) + row = session.query(Algo).filter(Algo.id == '__default__').first() + if row is None: + session.add( + Algo( + id='__default__', + display_name='Default', + description='Default mock parsing algorithm', + info_pickle=cloudpickle.dumps(default_info), + created_at=now_ts(), + updated_at=now_ts(), + ) + ) + + def register_callback(self, callback_func: Callable[[TaskCallbackRequest], None]): + self._callback_func = callback_func + + def _load_task(self, task_id: str): + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) + row = session.query(Task).filter(Task.task_id == task_id).first() + return _orm_to_dict(row) if row else None + + def _emit_callback(self, callback_payload: TaskCallbackRequest, callback_url: Optional[str]): + if self._callback_func: + self._callback_func(callback_payload) + return + if callback_url: + response = requests.post(callback_url, json=callback_payload.model_dump(), timeout=5) + if response.status_code >= 400: + raise RuntimeError(f'callback failed: {response.status_code} {response.text}') + + def _update_task(self, task_id: str, **fields): + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) + row = session.query(Task).filter(Task.task_id == task_id).first() + if row is None: + return None + for key, value in fields.items(): + setattr(row, key, value) + row.updated_at = now_ts() + session.add(row) + return _orm_to_dict(row) + + def _task_worker(self): + while not self._shutdown: + try: + waiting_task = None + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) + waiting_task = ( + session.query(Task) + .filter(Task.status == DocStatus.WAITING.value) + .order_by(Task.priority.desc(), Task.created_at.asc()) + .first() + ) + if waiting_task is not None: + waiting_task.status = DocStatus.WORKING.value + waiting_task.started_at = now_ts() + waiting_task.updated_at = now_ts() + session.add(waiting_task) + waiting_task = _orm_to_dict(waiting_task) + if waiting_task is None: + time.sleep(self._poll_interval) + continue + + start_callback = TaskCallbackRequest( + task_id=waiting_task['task_id'], + event_type=CallbackEventType.START, + status=DocStatus.WORKING, + payload={ + 'task_type': waiting_task['task_type'], + 'doc_id': waiting_task['doc_id'], + 'kb_id': waiting_task['kb_id'], + 'algo_id': waiting_task['algo_id'], + }, + ) + self._emit_callback(start_callback, waiting_task.get('callback_url')) + + # Mock workload + time.sleep(self._poll_interval) + final_status = ( + DocStatus.DELETED.value + if waiting_task['task_type'] == 'DOC_DELETE' + else DocStatus.SUCCESS.value + ) + done = 
self._update_task( + waiting_task['task_id'], + status=final_status, + finished_at=now_ts(), + error_code=None, + error_msg=None, + ) + finish_callback = TaskCallbackRequest( + task_id=waiting_task['task_id'], + event_type=CallbackEventType.FINISH, + status=DocStatus(final_status), + error_code=None, + error_msg=None, + payload={ + 'task_type': waiting_task['task_type'], + 'doc_id': waiting_task['doc_id'], + 'kb_id': waiting_task['kb_id'], + 'algo_id': waiting_task['algo_id'], + 'result': done, + }, + ) + self._emit_callback(finish_callback, waiting_task.get('callback_url')) + except Exception as exc: + LOG.error(f'[ParsingTaskServer] worker loop error: {exc}, {traceback.format_exc()}') + time.sleep(self._poll_interval) + + @app.post('/v1/internal/tasks/create') + def create_task(self, request: TaskCreateRequest): + self._lazy_init() + now = now_ts() + payload = request.model_dump(mode='json') + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) + exists = session.query(Task).filter(Task.task_id == request.task_id).first() + if exists is not None: + return BaseResponse(code=200, msg='success', data={'task': _orm_to_dict(exists), 'deduped': True}) + session.add( + Task( + task_id=request.task_id, + task_type=request.task_type.value, + doc_id=request.doc_id, + kb_id=request.kb_id, + algo_id=request.algo_id, + status=DocStatus.WAITING.value, + priority=request.priority, + message=json.dumps(payload, ensure_ascii=False), + callback_url=request.callback_url, + error_code=None, + error_msg=None, + created_at=now, + updated_at=now, + started_at=None, + finished_at=None, + ) + ) + task = self._load_task(request.task_id) + return BaseResponse(code=200, msg='success', data={'task': task, 'deduped': False}) + + @app.post('/v1/internal/tasks/cancel') + def cancel_task(self, request: Dict[str, str]): + self._lazy_init() + task_id = request.get('task_id') + if not task_id: + raise fastapi.HTTPException(status_code=400, detail='task_id is required') + task = self._load_task(task_id) + if not task: + return BaseResponse(code=404, msg='task not found', data={'task_id': task_id, 'cancel_status': False}) + if task['status'] != DocStatus.WAITING.value: + return BaseResponse( + code=409, + msg='task cannot be canceled', + data={'task_id': task_id, 'cancel_status': False, 'status': task['status']}, + ) + task = self._update_task(task_id, status=DocStatus.CANCELED.value, finished_at=now_ts()) + callback = TaskCallbackRequest( + task_id=task_id, + event_type=CallbackEventType.FINISH, + status=DocStatus.CANCELED, + payload={ + 'task_type': task.get('task_type'), + 'doc_id': task.get('doc_id'), + 'kb_id': task.get('kb_id'), + 'algo_id': task.get('algo_id'), + }, + ) + try: + self._emit_callback(callback, task.get('callback_url')) + except Exception as exc: + LOG.warning(f'[ParsingTaskServer] cancel callback failed: {exc}') + return BaseResponse( + code=200, + msg='success', + data={'task_id': task_id, 'cancel_status': True, 'status': DocStatus.CANCELED.value}, + ) + + @app.get('/v1/tasks') + def list_tasks(self, status: Optional[List[str]] = None, page: int = 1, page_size: int = 20): + self._lazy_init() + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) + query = session.query(Task) + if status: + query = query.filter(Task.status.in_(status)) + total = query.count() + rows = ( + query.order_by(Task.created_at.desc()) + .offset(max(page - 1, 0) * page_size) + 
.limit(page_size) + .all() + ) + items = [_orm_to_dict(row) for row in rows] + return BaseResponse(code=200, msg='success', data={ + 'items': items, + 'total': total, + 'page': page, + 'page_size': page_size, + }) + + @app.get('/v1/tasks/{task_id}') + def get_task(self, task_id: str): + self._lazy_init() + task = self._load_task(task_id) + if task is None: + raise fastapi.HTTPException(status_code=404, detail='task not found') + return BaseResponse(code=200, msg='success', data=task) + + @app.get('/v1/algo/list') + def list_algorithms(self): + self._lazy_init() + with self._db_manager.get_session() as session: + Algo = self._db_manager.get_table_orm_class(ALGORITHM_TABLE_INFO['name']) + rows = session.query(Algo).order_by(Algo.created_at.asc()).all() + data = [ + {'algo_id': row.id, 'display_name': row.display_name, 'description': row.description} + for row in rows + ] + return BaseResponse(code=200, msg='success', data=data) + + @app.get('/v1/algo/{algo_id}/groups') + def get_algorithm_groups(self, algo_id: str): + self._lazy_init() + with self._db_manager.get_session() as session: + Algo = self._db_manager.get_table_orm_class(ALGORITHM_TABLE_INFO['name']) + row = session.query(Algo).filter(Algo.id == algo_id).first() + if row is None: + raise fastapi.HTTPException(status_code=404, detail='algo not found') + info = cloudpickle.loads(row.info_pickle) + node_groups = info.get('node_groups', {}) if isinstance(info, dict) else {} + data = [] + for name, group in node_groups.items(): + data.append({ + 'name': name, + 'type': group.get('group_type'), + 'display_name': group.get('display_name'), + }) + return BaseResponse(code=200, msg='success', data=data) + + @app.get('/v1/health') + def health(self): + self._lazy_init() + healthy = self._task_thread is not None and self._task_thread.is_alive() + return BaseResponse(code=200 if healthy else 503, msg='success' if healthy else 'unhealthy', data={ + 'status': 'ok' if healthy else 'degraded', + 'version': 'v1-mock', + 'deps': { + 'sql': bool(self._db_manager), + 'worker': healthy, + }, + }) + + def __init__( + self, + port: Optional[int] = None, + url: Optional[str] = None, + db_config: Optional[Dict[str, Any]] = None, + poll_interval: float = 0.05, + callback_func: Optional[Callable[[TaskCallbackRequest], None]] = None, + launcher=None, + ): + super().__init__() + self._raw_impl = None + self._db_config = db_config or _get_default_db_config('doc_service_parser') + if url: + self._impl = UrlModule(url=ensure_call_endpoint(url)) + else: + self._raw_impl = ParsingTaskServer._Impl( + db_config=self._db_config, + poll_interval=poll_interval, + callback_func=callback_func, + ) + self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) + + def start(self): + result = super().start() + if self._raw_impl: + self._dispatch('_lazy_init') + return result + + def stop(self): + if self._raw_impl: + self._dispatch('stop') + if isinstance(self._impl, ServerModule): + self._impl.stop() + + def _dispatch(self, method: str, *args, **kwargs): + impl = self._impl + if isinstance(impl, ServerModule): + return impl._call(method, *args, **kwargs) + return getattr(impl, method)(*args, **kwargs) + + def register_callback(self, callback_func: Callable[[TaskCallbackRequest], None]): + if self._raw_impl: + self._raw_impl.register_callback(callback_func) + + def create_task(self, request: TaskCreateRequest): + return self._dispatch('create_task', request) + + def cancel_task(self, task_id: str): + return self._dispatch('cancel_task', {'task_id': task_id}) + + 
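    # Cancellation contract (see _Impl.cancel_task above): only tasks still in WAITING
    # can be canceled; an unknown task_id returns code 404 and a task that is already
    # WORKING or finished returns code 409 with cancel_status=False, otherwise the task
    # is marked CANCELED and a FINISH callback is emitted.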
def list_tasks(self, status: Optional[List[str]] = None, page: int = 1, page_size: int = 20): + return self._dispatch('list_tasks', status, page, page_size) + + def get_task(self, task_id: str): + return self._dispatch('get_task', task_id) + + def list_algorithms(self): + return self._dispatch('list_algorithms') + + def get_algorithm_groups(self, algo_id: str): + return self._dispatch('get_algorithm_groups', algo_id) diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index bdd304bfe..f35e23086 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -9,7 +9,7 @@ from lazyllm.tools.sql.sql_manager import SqlManager, DBStatus from lazyllm.common.bind import _MetaBind -from .doc_manager import DocManager +from .doc_service import DocServer from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder, BuiltinGroups, DocumentProcessor, NodeGroupType from .doc_node import DocNode from .doc_to_db import DocInfoSchema, DocToDbProcessor, extract_db_schema_from_files, SchemaExtractor @@ -43,7 +43,8 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, doc_fields: Optional[Dict[str, DocField]] = None, cloud: bool = False, doc_files: Optional[List[str]] = None, processor: Optional[DocumentProcessor] = None, display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', - schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): + schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, + doc_server_port: Optional[int] = None): super().__init__() self._origin_path, self._doc_files, self._cloud = dataset_path, doc_files, cloud @@ -69,7 +70,8 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, store=store_conf, processor=processor, algo_name=name, display_name=display_name, description=description, schema_extractor=schema_extractor)}) - if manager: self._manager = ServerModule(DocManager(self._dlm), launcher=self._launcher) + if manager: + self._manager = DocServer(launcher=self._launcher, storage_dir=dataset_path, port=doc_server_port) if manager == 'ui': self._docweb = DocWebModule(doc_server=self._manager) if server: self._kbs = ServerModule(self._kbs, port=(None if isinstance(server, bool) else int(server))) self._global_metadata_desc = doc_fields @@ -99,7 +101,8 @@ def _register_submodules(self, m): if isinstance(embed, ModuleBase): self._submodules.append(embed) return m - def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, store_conf: Optional[Dict] = None, + def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, + store_conf: Optional[Dict] = None, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): embed = self._get_embeds(embed) if embed else self._embed @@ -138,7 +141,8 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal doc_files: Optional[List[str]] = None, doc_fields: Dict[str, DocField] = None, store_conf: Optional[Dict] = None, display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', - schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): + schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, + doc_server_port: Optional[int] = None): super().__init__() if create_ui: lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') @@ -179,7 +183,8 @@ def __init__(self, dataset_path: 
Optional[str] = None, embed: Optional[Union[Cal self._manager = Document._Manager(dataset_path, embed, manager, server, name, launcher, store_conf, doc_fields, cloud=cloud, doc_files=doc_files, processor=processor, display_name=display_name, description=description, - schema_extractor=schema_extractor) + schema_extractor=schema_extractor, + doc_server_port=doc_server_port) self._curr_group = name self._doc_to_db_processor: DocToDbProcessor = None self._graph_document: weakref.ref = None diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index 8e11a0d0f..0c67a7c65 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -72,6 +72,7 @@ class TaskType(str, Enum): DOC_DELETE = 'DOC_DELETE' DOC_UPDATE_META = 'DOC_UPDATE_META' DOC_REPARSE = 'DOC_REPARSE' + DOC_TRANSFER = 'DOC_TRANSFER' def _get_task_type_weight(task_type: str) -> int: @@ -81,6 +82,7 @@ def _get_task_type_weight(task_type: str) -> int: TaskType.DOC_UPDATE_META.value: 30, TaskType.DOC_ADD.value: 100, TaskType.DOC_REPARSE.value: 100, + TaskType.DOC_TRANSFER.value: 100, } return weight_map.get(task_type, 100) diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index a0de7b1e5..6c94156e6 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -496,6 +496,12 @@ def start(self): raise return result + def wait(self): + impl = self._impl + if isinstance(impl, ServerModule): + return impl.wait() + LOG.warning('[DocumentProcessor] wait() is no-op in UrlModule mode') + def _dispatch(self, method: str, *args, **kwargs): try: impl = self._impl diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index ac768325e..4ce0c364a 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -278,6 +278,12 @@ def start(self): LOG.info('[DocumentProcessorWorker] Worker started') return result + def wait(self): + impl = self._worker_impl + if isinstance(impl, ServerModule): + return impl.wait() + LOG.warning('[DocumentProcessorWorker] wait() is no-op in local mode') + def stop(self): LOG.info('[DocumentProcessorWorker] Stopping worker...') self._dispatch('shutdown') diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index c6b0eebe9..f06674025 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -883,6 +883,13 @@ def _check_batch(fn): batch_size = getattr(fn, 'batch_size', None) return isinstance(batch_size, Integral) and batch_size > 1 + def _check_empty_embedding_item(vec, embed_key: str, idx: int) -> None: + if vec is None: + raise ValueError(f'[LazyLLM - parallel_do_embedding][{embed_key}] invalid embedding at index {idx}: None') + if isinstance(vec, (list, dict)) and len(vec) == 0: + raise ValueError(f'[LazyLLM - parallel_do_embedding][{embed_key}] ' + f'invalid embedding at index {idx}: empty {type(vec).__name__}') + def _process_key(k: str, knodes: List[DocNode]): try: fn = embed[k] @@ -892,7 +899,8 @@ def _process_key(k: str, knodes: List[DocNode]): if len(vecs) != len(texts): raise ValueError(f'[LazyLLM - parallel_do_embedding][{k}] batch size mismatch: ' f'[text_num:{len(texts)}] vs [vec_num:{len(vecs)}]') - for n, v in zip(knodes, vecs): + for idx, (n, v) in enumerate(zip(knodes, vecs)): + _check_empty_embedding_item(v, k, idx) n.set_embedding(k, v) return From 809821f9a712abdc87fecced174e41ec19b5bf68 Mon Sep 
17 00:00:00 2001 From: chenjiahao Date: Tue, 10 Mar 2026 15:40:53 +0800 Subject: [PATCH 09/46] fix dbname init --- lazyllm/tools/sql/sql_manager.py | 63 +++++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/lazyllm/tools/sql/sql_manager.py b/lazyllm/tools/sql/sql_manager.py index d0c7f53e1..addf48434 100644 --- a/lazyllm/tools/sql/sql_manager.py +++ b/lazyllm/tools/sql/sql_manager.py @@ -106,7 +106,8 @@ def _create_tables_by_info(self, tables_info: TablesInfo): column_name = column_info.name is_primary = column_info.is_primary_key default_value = column_info.default - real_type = self._sql_type_for(column_type) + # Use text for unsupported column type + real_type = self.PYTYPE_TO_SQL_MAP.get(column_type, sqlalchemy.Text) # Handle default value if default_value is not None: attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, @@ -129,17 +130,58 @@ def _gen_desc_by_info(self, tables_info: TablesInfo) -> dict: desc_dict[table_info.name] = table_comment return desc_dict - def _gen_conn_url(self) -> str: + def _gen_conn_url(self, db_name: str = None) -> str: + db_name = self._db_name if db_name is None else db_name if self._db_type == 'sqlite': - conn_url = f'sqlite:///{self._db_name}{("?" + self._options_str) if self._options_str else ""}' + conn_url = f'sqlite:///{db_name}{("?" + self._options_str) if self._options_str else ""}' else: driver = self.DB_DRIVER_MAP.get(self._db_type if self._db_type != 'tidb' else 'mysql', '') password = quote_plus(self._password) prefix = 'mysql' if self._db_type == 'tidb' else self._db_type + db_path = f'/{db_name}' if db_name else '/' conn_url = (f'{prefix}{("+" + driver) if driver else ""}://{self._user}:{password}@{self._host}' - f':{self._port}/{self._db_name}{("?" + self._options_str) if self._options_str else ""}') + f':{self._port}{db_path}{("?" 
+ self._options_str) if self._options_str else ""}') return conn_url + def _mysql_engine_kwargs(self) -> dict: + kwargs = { + 'pool_size': 10, + 'max_overflow': 20, + 'pool_pre_ping': True, + } + if self._db_type == 'tidb': + kwargs.update({'pool_recycle': 300, 'connect_args': {}, 'echo': False}) + else: + kwargs.update({'pool_recycle': 3600}) + return kwargs + + @staticmethod + def _get_operational_error_code(error: OperationalError): + args = getattr(getattr(error, 'orig', None), 'args', ()) + return args[0] if args else None + + def _ensure_database_exists(self, conn_url: str): + if self._db_type not in ('mysql', 'mysql+pymysql', 'tidb'): + return + probe_engine = sqlalchemy.create_engine(conn_url, **self._mysql_engine_kwargs()) + try: + with probe_engine.connect(): + return + except OperationalError as e: + if self._get_operational_error_code(e) != 1049: + raise + finally: + probe_engine.dispose() + + admin_engine = sqlalchemy.create_engine(self._gen_conn_url(''), **self._mysql_engine_kwargs()) + try: + escaped_db_name = self._db_name.replace('`', '``') + with admin_engine.connect() as conn: + conn.execute(sqlalchemy.text(f'CREATE DATABASE IF NOT EXISTS `{escaped_db_name}`')) + conn.commit() + finally: + admin_engine.dispose() + @property def engine(self): if self._engine is None: @@ -156,16 +198,9 @@ def engine(self): conn.execute(sqlalchemy.text('PRAGMA synchronous=NORMAL')) conn.execute(sqlalchemy.text('PRAGMA busy_timeout=30000')) conn.commit() - elif self._db_type == 'tidb': - self._engine = sqlalchemy.create_engine( - conn_url, - pool_size=10, - max_overflow=20, - pool_pre_ping=True, - pool_recycle=300, - connect_args={}, - echo=False, - ) + elif self._db_type in ('mysql', 'mysql+pymysql', 'tidb'): + self._ensure_database_exists(conn_url) + self._engine = sqlalchemy.create_engine(conn_url, **self._mysql_engine_kwargs()) else: self._engine = sqlalchemy.create_engine( conn_url, From 61c5becb7d5da313beae0896b46511a95260bafc Mon Sep 17 00:00:00 2001 From: chenjiahao Date: Tue, 10 Mar 2026 16:07:04 +0800 Subject: [PATCH 10/46] fix tidb primary kay type --- lazyllm/tools/sql/sql_manager.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/lazyllm/tools/sql/sql_manager.py b/lazyllm/tools/sql/sql_manager.py index addf48434..e37e438b9 100644 --- a/lazyllm/tools/sql/sql_manager.py +++ b/lazyllm/tools/sql/sql_manager.py @@ -87,9 +87,13 @@ def _init_tables_by_info(self, tables_info_dict): except pydantic.ValidationError as e: raise ValueError(f'Validate tables_info_dict failed: {str(e)}') - def _sql_type_for(self, py_type: str): + def _sql_type_for(self, py_type: str, *, is_primary_key: bool = False): t = py_type.lower() if self._db_type in ('mysql', 'tidb', 'mysql+pymysql'): + # MySQL/TiDB do not allow TEXT/BLOB columns to be used as primary keys + # without a prefix length. Use VARCHAR for identifier-like key columns. + if is_primary_key and t in ('string', 'text'): + return sqlalchemy.String(255) if t == 'list': return sqlalchemy.JSON if t == 'uuid': @@ -106,8 +110,8 @@ def _create_tables_by_info(self, tables_info: TablesInfo): column_name = column_info.name is_primary = column_info.is_primary_key default_value = column_info.default - # Use text for unsupported column type - real_type = self.PYTYPE_TO_SQL_MAP.get(column_type, sqlalchemy.Text) + # Keep cross-db compatibility while handling MySQL/TiDB PK restrictions. 
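            # On MySQL/TiDB, string/text primary-key columns become VARCHAR(255) because
            # TEXT/BLOB keys would require an index prefix length; see _sql_type_for above
            # for the per-dialect details.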
+ real_type = self._sql_type_for(column_type, is_primary_key=is_primary) # Handle default value if default_value is not None: attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, From cea23e314eb962dd86eb24f465b068704809fab5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 11 Mar 2026 18:48:04 +0800 Subject: [PATCH 11/46] temp --- .gitignore | 1 + examples/rag/doc_service_standalone.py | 411 +++++++- lazyllm/common/utils.py | 31 +- lazyllm/tools/rag/doc_service/__init__.py | 3 +- lazyllm/tools/rag/doc_service/base.py | 98 +- lazyllm/tools/rag/doc_service/doc_manager.py | 767 +++++++++++--- lazyllm/tools/rag/doc_service/doc_server.py | 220 +++- .../tools/rag/doc_service/parsing_server.py | 434 -------- lazyllm/tools/rag/parsing_service/base.py | 14 +- lazyllm/tools/rag/parsing_service/queue.py | 50 + lazyllm/tools/rag/parsing_service/server.py | 246 ++++- lazyllm/tools/rag/parsing_service/worker.py | 70 +- lazyllm/tools/rag/store/document_store.py | 19 +- tests/basic_tests/RAG/test_doc_processor.py | 2 +- .../basic_tests/RAG/test_doc_service_mock.py | 966 ++++++++++++++++++ tests/basic_tests/Tools/test_sql_manager.py | 64 ++ 16 files changed, 2722 insertions(+), 674 deletions(-) delete mode 100644 lazyllm/tools/rag/doc_service/parsing_server.py create mode 100644 tests/basic_tests/RAG/test_doc_service_mock.py create mode 100644 tests/basic_tests/Tools/test_sql_manager.py diff --git a/.gitignore b/.gitignore index 0e34cdd67..1b2ff5977 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__ test/ dist/ tmp/ +tmp*/ build *.lock *.db diff --git a/examples/rag/doc_service_standalone.py b/examples/rag/doc_service_standalone.py index 3e968f75c..99e1e38e0 100644 --- a/examples/rag/doc_service_standalone.py +++ b/examples/rag/doc_service_standalone.py @@ -1,67 +1,428 @@ -'''Start standalone DocService mock server. +'''Start standalone DocService. + +Modes: +1. Full stack mode (default): starts algorithm registration, real + DocumentProcessor, and DocServer in one process. +2. External parser mode: starts only DocServer and connects to an existing + parsing service with ``--parser-url``. 
Run: python examples/rag/doc_service_standalone.py --wait + python examples/rag/doc_service_standalone.py --parser-url http://127.0.0.1:9966 --wait ''' from __future__ import annotations import argparse +import json import os import tempfile +import threading import time +from datetime import datetime +from typing import Any, Dict, Optional +from uuid import uuid4 +import requests -def main(): - parser = argparse.ArgumentParser(description='Standalone DocService mock server.') - parser.add_argument('--port', type=int, default=None, help='DocServer listen port.') - parser.add_argument('--wait', action='store_true', help='Keep server alive for manual API/docs inspection.') - args = parser.parse_args() +import lazyllm +from lazyllm import Document +from lazyllm.tools.rag.doc_service import DocServer +from lazyllm.tools.rag.doc_service.base import CallbackEventType, DocStatus, TaskCallbackRequest, TaskCreateRequest +from lazyllm.tools.rag.parsing_service import DocumentProcessor +from lazyllm.tools.rag.parsing_service.base import TaskStatus, TaskType +from lazyllm.tools.rag.utils import BaseResponse - from lazyllm.tools.rag.doc_service import DocServer +REAL_ALGO_ID = 'real-standalone-algo' - tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_standalone_') - storage_dir = os.path.join(tmp_dir, 'uploads') - os.makedirs(storage_dir, exist_ok=True) - db_config = { + +def _make_db_config(db_name: str) -> Dict[str, Any]: + return { 'db_type': 'sqlite', 'user': None, 'password': None, 'host': None, 'port': None, - 'db_name': os.path.join(tmp_dir, 'doc_service.db'), + 'db_name': db_name, } - parser_db_config = { - 'db_type': 'sqlite', - 'user': None, - 'password': None, - 'host': None, - 'port': None, - 'db_name': os.path.join(tmp_dir, 'doc_service_parser.db'), + + +def _wait_until(predicate, timeout: float = 20.0, interval: float = 0.1): + deadline = time.time() + timeout + last = None + while time.time() < deadline: + last = predicate() + if last: + return last + time.sleep(interval) + raise RuntimeError(f'condition not satisfied before timeout, last={last!r}') + + +def _wait_http_ok(url: str, timeout: float = 20.0): + def _poll(): + try: + resp = requests.get(url, timeout=3) + if resp.status_code == 200: + return resp + except Exception: + return None + return None + + return _wait_until(_poll, timeout=timeout) + + +def _task_status_to_doc_status(task_status: str) -> DocStatus: + mapping = { + TaskStatus.SUCCESS.value: DocStatus.SUCCESS, + TaskStatus.FAILED.value: DocStatus.FAILED, + TaskStatus.CANCELED.value: DocStatus.CANCELED, + } + if task_status not in mapping: + raise RuntimeError(f'unsupported task status: {task_status}') + return mapping[task_status] + + +class _RealProcessorTaskAdapter: + def __init__(self, parser_base_url: str, manager, upstream_client): + self._parser_base_url = parser_base_url.rstrip('/') + self._manager = manager + self._upstream_client = upstream_client + self._tasks: Dict[str, Dict[str, Any]] = {} + self._lock = threading.Lock() + + def _load_doc(self, doc_id: str) -> Dict[str, Any]: + doc = self._manager._get_doc(doc_id) + if doc is None: + raise RuntimeError(f'doc not found: {doc_id}') + return doc + + @staticmethod + def _load_metadata(doc: Dict[str, Any]) -> Dict[str, Any]: + raw = doc.get('meta') + return json.loads(raw) if raw else {} + + def _record_task(self, req: TaskCreateRequest) -> Dict[str, Any]: + now = datetime.now().isoformat() + task = { + 'task_id': req.task_id, + 'task_type': req.task_type.value, + 'doc_id': req.doc_id, + 'kb_id': req.kb_id, + 
'algo_id': req.algo_id, + 'status': TaskStatus.WAITING.value, + 'priority': req.priority, + 'callback_url': req.callback_url, + 'error_code': None, + 'error_msg': None, + 'created_at': now, + 'updated_at': now, + 'started_at': None, + 'finished_at': None, + } + with self._lock: + self._tasks[req.task_id] = task + return task + + def _dispatch_add_like_task(self, req: TaskCreateRequest, doc: Dict[str, Any], reparse: bool = False): + file_info = { + 'file_path': doc['path'], + 'doc_id': req.doc_id, + 'metadata': self._load_metadata(doc), + } + if reparse: + file_info['reparse_group'] = 'CoarseChunk' + payload = { + 'task_id': req.task_id, + 'algo_id': req.algo_id, + 'kb_id': req.kb_id, + 'file_infos': [file_info], + 'priority': req.priority, + } + return requests.post(f'{self._parser_base_url}/doc/add', json=payload, timeout=15) + + def create_task(self, req: TaskCreateRequest): + self._record_task(req) + try: + if req.task_type == TaskType.DOC_ADD: + resp = self._dispatch_add_like_task(req, self._load_doc(req.doc_id)) + elif req.task_type == TaskType.DOC_REPARSE: + resp = self._dispatch_add_like_task(req, self._load_doc(req.doc_id), reparse=True) + elif req.task_type == TaskType.DOC_DELETE: + resp = requests.delete( + f'{self._parser_base_url}/doc/delete', + json={ + 'task_id': req.task_id, + 'algo_id': req.algo_id, + 'kb_id': req.kb_id, + 'doc_ids': [req.doc_id], + 'priority': req.priority, + }, + timeout=15, + ) + elif req.task_type == TaskType.DOC_UPDATE_META: + doc = self._load_doc(req.doc_id) + resp = requests.post( + f'{self._parser_base_url}/doc/meta/update', + json={ + 'task_id': req.task_id, + 'algo_id': req.algo_id, + 'kb_id': req.kb_id, + 'file_infos': [{ + 'file_path': doc['path'], + 'doc_id': req.doc_id, + 'metadata': req.metadata, + }], + 'priority': req.priority, + }, + timeout=15, + ) + else: + raise RuntimeError(f'unsupported task type: {req.task_type.value}') + if resp.status_code >= 400: + raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') + result = BaseResponse.model_validate(resp.json()) + if result.code != 200: + raise RuntimeError(f'parser task rejected: {result.msg}') + return result + except Exception: + with self._lock: + self._tasks.pop(req.task_id, None) + raise + + def mark_task_finished(self, task_id: str, task_status: str, + error_code: Optional[str] = None, error_msg: Optional[str] = None): + with self._lock: + task = self._tasks.get(task_id) + if task is None: + return None + finished_at = datetime.now().isoformat() + task['status'] = task_status + task['error_code'] = error_code + task['error_msg'] = error_msg + task['finished_at'] = finished_at + task['updated_at'] = finished_at + return dict(task) + + def cancel_task(self, task_id: str): + resp = requests.post(f'{self._parser_base_url}/doc/cancel', json={'task_id': task_id}, timeout=8) + if resp.status_code >= 400: + raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') + result = BaseResponse.model_validate(resp.json()) + if result.code == 200 and result.data and result.data.get('cancel_status'): + self.mark_task_finished(task_id, TaskStatus.CANCELED.value) + return result + + def list_tasks(self, status: Optional[list[str]], page: int, page_size: int): + with self._lock: + items = [dict(task) for task in self._tasks.values()] + if status: + items = [item for item in items if item['status'] in status] + items.sort(key=lambda item: item['created_at'], reverse=True) + total = len(items) + sliced = items[(page - 1) * page_size:page * page_size] + return 
BaseResponse( + code=200, + msg='success', + data={'items': sliced, 'total': total, 'page': page, 'page_size': page_size}, + ) + + def get_task(self, task_id: str): + with self._lock: + task = self._tasks.get(task_id) + if task is None: + return BaseResponse(code=404, msg='task not found', data=None) + return BaseResponse(code=200, msg='success', data=dict(task)) + + def list_algorithms(self): + return self._upstream_client.list_algorithms() + + def get_algorithm_groups(self, algo_id: str): + return self._upstream_client.get_algorithm_groups(algo_id) + + +def _make_post_func(state: Dict[str, Any]): + def _post_func(task_id: str, task_status: str, error_code: str = None, error_msg: str = None): + adapter = state['adapter'] + callback_url = state['callback_url'] + task = adapter.mark_task_finished(task_id, task_status, error_code, error_msg) + if task is None: + raise RuntimeError(f'untracked callback task: {task_id}') + callback = TaskCallbackRequest( + callback_id=str(uuid4()), + task_id=task_id, + event_type=CallbackEventType.FINISH, + status=_task_status_to_doc_status(task_status), + error_code=error_code, + error_msg=error_msg, + payload={ + 'task_type': task['task_type'], + 'doc_id': task['doc_id'], + 'kb_id': task['kb_id'], + 'algo_id': task['algo_id'], + }, + ) + resp = requests.post(callback_url, json=callback.model_dump(mode='json'), timeout=8) + resp.raise_for_status() + return True + + return _post_func + + +def _build_store_conf(root_dir: str) -> Dict[str, Any]: + segment_store_path = os.path.join(root_dir, 'segments.db') + milvus_store_path = os.path.join(root_dir, 'milvus_lite.db') + open(segment_store_path, 'a', encoding='utf-8').close() + open(milvus_store_path, 'a', encoding='utf-8').close() + return { + 'segment_store': { + 'type': 'map', + 'kwargs': {'uri': segment_store_path}, + }, + 'vector_store': { + 'type': 'milvus', + 'kwargs': { + 'uri': milvus_store_path, + 'index_kwargs': { + 'index_type': 'FLAT', + 'metric_type': 'COSINE', + }, + }, + }, } + +def _start_full_stack(args): + tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_standalone_') + storage_dir = os.path.join(tmp_dir, 'uploads') + os.makedirs(storage_dir, exist_ok=True) + parser_db = os.path.join(tmp_dir, 'parser.db') + doc_db = os.path.join(tmp_dir, 'doc_service.db') + callback_state: Dict[str, Any] = {} + + parser = DocumentProcessor( + port=args.parser_port, + db_config=_make_db_config(parser_db), + num_workers=args.num_workers, + post_func=_make_post_func(callback_state), + ) + parser.start() + parser_base_url = parser._impl._url.rsplit('/', 1)[0] + _wait_http_ok(f'{parser_base_url}/health') + + store_conf = _build_store_conf(tmp_dir) + document = Document( + dataset_path=None, + name=args.algo_id, + embed={'vec_dense': lambda x: [1.0, 2.0, 3.0]}, + store_conf=store_conf, + display_name='Standalone Real Algo', + manager=DocumentProcessor(url=parser_base_url), + description='Algorithm registered by standalone doc service example', + ) + document.create_node_group( + name='line', + transform=lambda x: x.split('\n'), + parent='CoarseChunk', + display_name='Line Chunk', + ) + document.activate_group('CoarseChunk', embed_keys=['vec_dense']) + document.activate_group('line', embed_keys=['vec_dense']) + document.start() + + _wait_until( + lambda: any( + item.get('algo_id') == args.algo_id + for item in requests.get(f'{parser_base_url}/algo/list', timeout=5).json().get('data', []) + ) + ) + + server = DocServer( + storage_dir=storage_dir, + db_config=_make_db_config(doc_db), + parser_url=parser_base_url, + 
port=args.port, + ) + server.start() + base_url = server.url.rsplit('/', 1)[0] + _wait_http_ok(f'{base_url}/v1/health') + + raw_impl = server._raw_impl + raw_impl._lazy_init() + adapter = _RealProcessorTaskAdapter( + parser_base_url=parser_base_url, + manager=raw_impl._manager, + upstream_client=raw_impl._manager._parser_client, + ) + raw_impl._manager._parser_client = adapter + callback_state['adapter'] = adapter + callback_state['callback_url'] = raw_impl._manager._callback_url + + print(f'DocService URL: {base_url}', flush=True) + print(f'DocService Docs: {base_url}/docs', flush=True) + print(f'Parser URL: {parser_base_url}', flush=True) + print(f'Parser Docs: {parser_base_url}/docs', flush=True) + print(f'Algorithm ID: {args.algo_id}', flush=True) + print(f'Storage Dir: {storage_dir}', flush=True) + print(f'Doc DB: {doc_db}', flush=True) + print(f'Parser DB: {parser_db}', flush=True) + print(f'Tmp Dir: {tmp_dir}', flush=True) + + try: + if args.wait: + print('Full stack is running. Press Ctrl+C to stop...', flush=True) + threading.Event().wait() + finally: + server.stop() + try: + parser.drop_algorithm(args.algo_id) + except Exception: + pass + parser.stop() + + +def _start_doc_server_only(args): + tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_standalone_') + storage_dir = os.path.join(tmp_dir, 'uploads') + os.makedirs(storage_dir, exist_ok=True) + doc_db = os.path.join(tmp_dir, 'doc_service.db') server = DocServer( storage_dir=storage_dir, - db_config=db_config, - parser_db_config=parser_db_config, + db_config=_make_db_config(doc_db), + parser_url=args.parser_url, port=args.port, ) server.start() base_url = server.url.rsplit('/', 1)[0] print(f'DocService URL: {base_url}', flush=True) - print(f'Swagger Docs: {base_url}/docs', flush=True) + print(f'DocService Docs: {base_url}/docs', flush=True) + print(f'Parser URL: {args.parser_url}', flush=True) print(f'Storage Dir: {storage_dir}', flush=True) - print(f'Doc DB: {db_config["db_name"]}', flush=True) - print(f'Parser DB: {parser_db_config["db_name"]}', flush=True) + print(f'Doc DB: {doc_db}', flush=True) try: if args.wait: - print('Server is running. Press Ctrl+C to stop...', flush=True) + print('DocService is running. 
Press Ctrl+C to stop...', flush=True) while True: time.sleep(1) finally: server.stop() +def main(): + parser = argparse.ArgumentParser(description='Standalone DocService server.') + parser.add_argument('--port', type=int, default=8848, help='DocServer listen port.') + parser.add_argument('--parser-port', type=int, default=9966, help='DocumentProcessor listen port.') + parser.add_argument('--parser-url', type=str, default=None, help='Existing parsing service base URL.') + parser.add_argument('--algo-id', type=str, default=REAL_ALGO_ID, help='Algorithm id to register in full stack mode.') + parser.add_argument('--num-workers', type=int, default=1, help='DocumentProcessor worker count.') + parser.add_argument('--wait', action='store_true', help='Keep server alive for manual API inspection.') + args = parser.parse_args() + + if args.parser_url: + _start_doc_server_only(args) + else: + _start_full_stack(args) + + if __name__ == '__main__': main() diff --git a/lazyllm/common/utils.py b/lazyllm/common/utils.py index bfed8e206..68745f06e 100644 --- a/lazyllm/common/utils.py +++ b/lazyllm/common/utils.py @@ -3,6 +3,7 @@ from typing import Union, Dict, Callable, Any, Optional import re import os +import sys from contextlib import contextmanager import cloudpickle import ast @@ -161,11 +162,37 @@ def str2bool(v: str) -> bool: raise argparse.ArgumentTypeError('Boolean value expected.') def dump_obj(f): + def _collect_test_modules(obj): + modules = [] + seen = set() + candidates = [obj] + if hasattr(obj, '__dict__'): + candidates.extend(obj.__dict__.values()) + for candidate in candidates: + module_name = getattr(candidate, '__module__', None) + if not module_name: + continue + if not (module_name.startswith('test_') or module_name.startswith('tmp.tests.')): + continue + module = sys.modules.get(module_name) + if module is None or module_name in seen: + continue + seen.add(module_name) + modules.append(module) + return modules + @contextmanager def env_helper(): os.environ['LAZYLLM_ON_CLOUDPICKLE'] = 'ON' - yield - os.environ['LAZYLLM_ON_CLOUDPICKLE'] = 'OFF' + modules = _collect_test_modules(f) + for module in modules: + cloudpickle.register_pickle_by_value(module) + try: + yield + finally: + for module in modules: + cloudpickle.unregister_pickle_by_value(module) + os.environ['LAZYLLM_ON_CLOUDPICKLE'] = 'OFF' with env_helper(): return None if f is None else base64.b64encode(cloudpickle.dumps(f)).decode('utf-8') diff --git a/lazyllm/tools/rag/doc_service/__init__.py b/lazyllm/tools/rag/doc_service/__init__.py index 94dc21ff8..9f3ef78e9 100644 --- a/lazyllm/tools/rag/doc_service/__init__.py +++ b/lazyllm/tools/rag/doc_service/__init__.py @@ -1,5 +1,4 @@ from .doc_server import DocServer from .doc_manager import DocManager -from .parsing_server import ParsingTaskServer -__all__ = ['DocServer', 'DocManager', 'ParsingTaskServer'] +__all__ = ['DocServer', 'DocManager'] diff --git a/lazyllm/tools/rag/doc_service/base.py b/lazyllm/tools/rag/doc_service/base.py index a69486db8..6f6e1e13f 100644 --- a/lazyllm/tools/rag/doc_service/base.py +++ b/lazyllm/tools/rag/doc_service/base.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional from uuid import uuid4 -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from ..parsing_service.base import TaskType @@ -84,6 +84,12 @@ class AddFileItem(BaseModel): doc_id: Optional[str] = None metadata: Dict[str, Any] = Field(default_factory=dict) + @model_validator(mode='after') + def validate_file_path(self): + if not 
self.file_path or not self.file_path.strip(): + raise ValueError('file_path is required') + return self + class AddRequest(BaseModel): items: List[AddFileItem] @@ -92,6 +98,12 @@ class AddRequest(BaseModel): source_type: SourceType = SourceType.EXTERNAL idempotency_key: Optional[str] = None + @model_validator(mode='after') + def validate_items(self): + if not self.items: + raise ValueError('items is required') + return self + class UploadRequest(BaseModel): items: List[AddFileItem] @@ -100,6 +112,12 @@ class UploadRequest(BaseModel): source_type: SourceType = SourceType.API idempotency_key: Optional[str] = None + @model_validator(mode='after') + def validate_items(self): + if not self.items: + raise ValueError('items is required') + return self + class ReparseRequest(BaseModel): doc_ids: List[str] @@ -107,6 +125,12 @@ class ReparseRequest(BaseModel): algo_id: str = '__default__' idempotency_key: Optional[str] = None + @model_validator(mode='after') + def validate_doc_ids(self): + if not self.doc_ids: + raise ValueError('doc_ids is required') + return self + class DeleteRequest(BaseModel): doc_ids: List[str] @@ -114,6 +138,12 @@ class DeleteRequest(BaseModel): algo_id: str = '__default__' idempotency_key: Optional[str] = None + @model_validator(mode='after') + def validate_doc_ids(self): + if not self.doc_ids: + raise ValueError('doc_ids is required') + return self + class TransferItem(BaseModel): doc_id: str @@ -128,6 +158,12 @@ class TransferRequest(BaseModel): items: List[TransferItem] idempotency_key: Optional[str] = None + @model_validator(mode='after') + def validate_items(self): + if not self.items: + raise ValueError('items is required') + return self + class MetadataPatchItem(BaseModel): doc_id: str @@ -140,6 +176,41 @@ class MetadataPatchRequest(BaseModel): algo_id: str = '__default__' idempotency_key: Optional[str] = None + @model_validator(mode='after') + def validate_items(self): + if not self.items: + raise ValueError('items is required') + return self + + +class KbCreateRequest(BaseModel): + kb_id: str + display_name: Optional[str] = None + description: Optional[str] = None + owner_id: Optional[str] = None + meta: Optional[Dict[str, Any]] = None + algo_id: str = '__default__' + idempotency_key: Optional[str] = None + + +class KbUpdateRequest(BaseModel): + display_name: Optional[str] = None + description: Optional[str] = None + owner_id: Optional[str] = None + meta: Optional[Dict[str, Any]] = None + algo_id: Optional[str] = None + idempotency_key: Optional[str] = None + + +class KbBatchQueryRequest(BaseModel): + kb_ids: List[str] + + @model_validator(mode='after') + def validate_kb_ids(self): + if not self.kb_ids: + raise ValueError('kb_ids is required') + return self + IDEMPOTENCY_RECORDS_TABLE_INFO = { 'name': 'lazyllm_idempotency_records', @@ -174,6 +245,31 @@ class MetadataPatchRequest(BaseModel): } +DOC_SERVICE_TASKS_TABLE_INFO = { + 'name': 'lazyllm_doc_service_tasks', + 'comment': 'Doc service task state table', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'task_id', 'data_type': 'string', 'nullable': False, 'comment': 'Task ID'}, + {'name': 'task_type', 'data_type': 'string', 'nullable': False, 'comment': 'Task type'}, + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'comment': 'Document ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'algo_id', 'data_type': 'string', 'nullable': False, 
'comment': 'Algorithm ID'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'comment': 'Current task status'}, + {'name': 'message', 'data_type': 'text', 'nullable': True, 'comment': 'Task payload in JSON string'}, + {'name': 'error_code', 'data_type': 'string', 'nullable': True, 'comment': 'Error code'}, + {'name': 'error_msg', 'data_type': 'text', 'nullable': True, 'comment': 'Error message'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + {'name': 'started_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Started time'}, + {'name': 'finished_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Finished time'}, + ], +} + + DOCUMENTS_TABLE_INFO = { 'name': 'lazyllm_documents', 'comment': 'Document metadata table', diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index b597b9236..a4562b3bd 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -4,7 +4,7 @@ import json import os from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set from uuid import uuid4 import requests @@ -20,6 +20,7 @@ CallbackEventType, CALLBACK_RECORDS_TABLE_INFO, DeleteRequest, + DOC_SERVICE_TASKS_TABLE_INFO, DocServiceError, DOCUMENTS_TABLE_INFO, IDEMPOTENCY_RECORDS_TABLE_INFO, @@ -39,7 +40,14 @@ DocStatus, now_ts, ) -from .parsing_server import ParsingTaskServer +from ..parsing_service.base import ( + AddDocRequest as ParsingAddDocRequest, + CancelTaskRequest as ParsingCancelTaskRequest, + DeleteDocRequest as ParsingDeleteDocRequest, + FileInfo as ParsingFileInfo, + TaskStatus, + UpdateMetaRequest as ParsingUpdateMetaRequest, +) def _to_json(data: Optional[Dict[str, Any]]) -> str: @@ -78,15 +86,11 @@ def _sha256_file(file_path: str) -> str: class _ParserClient: - def __init__(self, parser_server: Optional[ParsingTaskServer] = None, parser_url: Optional[str] = None): - self._parser_server = parser_server - if parser_url: - parser_url = parser_url.rstrip('/') - if parser_url.endswith('/_call') or parser_url.endswith('/generate'): - parser_url = parser_url.rsplit('/', 1)[0] - self._parser_url = parser_url - else: - self._parser_url = None + def __init__(self, parser_url: str): + parser_url = parser_url.rstrip('/') + if parser_url.endswith('/_call') or parser_url.endswith('/generate'): + parser_url = parser_url.rsplit('/', 1)[0] + self._parser_url = parser_url def _post(self, path: str, payload: Dict[str, Any]): url = f'{self._parser_url}{path}' @@ -102,52 +106,87 @@ def _get(self, path: str, params: Optional[Dict[str, Any]] = None): raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') return resp.json() - def create_task(self, req: TaskCreateRequest): - if self._parser_server: - return self._parser_server.create_task(req) - data = self._post('/v1/internal/tasks/create', req.model_dump(mode='json')) + def _delete(self, path: str, payload: Optional[Dict[str, Any]] = None): + url = f'{self._parser_url}{path}' + resp = requests.delete(url, json=payload, timeout=8) + if resp.status_code >= 400: + raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') + return resp.json() + + def _get_with_fallback(self, paths: List[str], params: Optional[Dict[str, Any]] = None): + 
last_error = None + for path in paths: + try: + return self._get(path, params=params) + except RuntimeError as exc: + last_error = exc + if '404' not in str(exc): + raise + if last_error is not None: + raise last_error + raise RuntimeError('parser http error: no path provided') + + def add_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, file_path: str, + metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, + callback_url: Optional[str] = None): + req = ParsingAddDocRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + callback_url=callback_url, + feedback_url=callback_url, + file_infos=[ParsingFileInfo( + file_path=file_path, + doc_id=doc_id, + metadata=metadata or {}, + reparse_group=reparse_group, + )], + ) + data = self._post('/doc/add', req.model_dump(mode='json')) return BaseResponse.model_validate(data) - def cancel_task(self, task_id: str): - if self._parser_server: - return self._parser_server.cancel_task(task_id) - data = self._post('/v1/internal/tasks/cancel', {'task_id': task_id}) + def update_meta(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, + metadata: Optional[Dict[str, Any]] = None, file_path: Optional[str] = None, + callback_url: Optional[str] = None): + req = ParsingUpdateMetaRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + callback_url=callback_url, + feedback_url=callback_url, + file_infos=[ParsingFileInfo(file_path=file_path, doc_id=doc_id, metadata=metadata or {})], + ) + data = self._post('/doc/meta/update', req.model_dump(mode='json')) return BaseResponse.model_validate(data) - def list_tasks(self, status: Optional[List[str]], page: int, page_size: int): - if self._parser_server: - return self._parser_server.list_tasks(status=status, page=page, page_size=page_size) - params: Dict[str, Any] = {'page': page, 'page_size': page_size} - if status: - params['status'] = status - data = self._get('/v1/tasks', params=params) + def delete_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, + callback_url: Optional[str] = None): + req = ParsingDeleteDocRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + doc_ids=[doc_id], + callback_url=callback_url, + feedback_url=callback_url, + ) + data = self._delete('/doc/delete', req.model_dump(mode='json')) return BaseResponse.model_validate(data) - def get_task(self, task_id: str): - if self._parser_server: - try: - return self._parser_server.get_task(task_id) - except (fastapi.HTTPException, requests.RequestException): - return BaseResponse(code=404, msg='task not found', data=None) - try: - data = self._get(f'/v1/tasks/{task_id}') - return BaseResponse.model_validate(data) - except RuntimeError as exc: - if '404' in str(exc): - return BaseResponse(code=404, msg='task not found', data=None) - raise + def cancel_task(self, task_id: str): + req = ParsingCancelTaskRequest(task_id=task_id) + data = self._post('/doc/cancel', req.model_dump(mode='json')) + return BaseResponse.model_validate(data) def list_algorithms(self): - if self._parser_server: - return self._parser_server.list_algorithms() - data = self._get('/v1/algo/list') + data = self._get_with_fallback(['/v1/algo/list', '/algo/list']) return BaseResponse.model_validate(data) def get_algorithm_groups(self, algo_id: str): - if self._parser_server: - return self._parser_server.get_algorithm_groups(algo_id) try: - data = self._get(f'/v1/algo/{algo_id}/groups') + data = self._get_with_fallback([ + f'/v1/algo/{algo_id}/groups', + f'/algo/{algo_id}/group/info', + ]) return 
BaseResponse.model_validate(data) except RuntimeError as exc: if '404' in str(exc): @@ -159,22 +198,22 @@ class DocManager: def __init__( self, db_config: Optional[Dict[str, Any]] = None, - parser_server: Optional[ParsingTaskServer] = None, parser_url: Optional[str] = None, callback_url: Optional[str] = None, ): - if parser_server is None and not parser_url: - raise ValueError('Either parser_server or parser_url must be provided') + if not parser_url: + raise ValueError('parser_url is required') self._db_config = db_config or _get_default_db_config('doc_service') self._db_manager = SqlManager( **self._db_config, tables_info_dict={'tables': [DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, KB_ALGORITHM_TABLE_INFO, PARSE_STATE_TABLE_INFO, - IDEMPOTENCY_RECORDS_TABLE_INFO, CALLBACK_RECORDS_TABLE_INFO]}, + IDEMPOTENCY_RECORDS_TABLE_INFO, CALLBACK_RECORDS_TABLE_INFO, + DOC_SERVICE_TASKS_TABLE_INFO]}, ) self._ensure_indexes() - self._parser_client = _ParserClient(parser_server=parser_server, parser_url=parser_url) + self._parser_client = _ParserClient(parser_url=parser_url) self._callback_url = callback_url self._upsert_default_kb() @@ -212,6 +251,12 @@ def _ensure_indexes(self): 'ON lazyllm_idempotency_records(endpoint, idempotency_key)', 'CREATE UNIQUE INDEX IF NOT EXISTS uq_callback_id ' 'ON lazyllm_callback_records(callback_id)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_doc_service_task_id ' + 'ON lazyllm_doc_service_tasks(task_id)', + 'CREATE INDEX IF NOT EXISTS idx_doc_service_task_status ' + 'ON lazyllm_doc_service_tasks(status, updated_at)', + 'CREATE INDEX IF NOT EXISTS idx_doc_service_task_doc ' + 'ON lazyllm_doc_service_tasks(doc_id, kb_id, algo_id)', ] for stmt in stmts: self._db_manager.execute_commit(stmt) @@ -222,7 +267,8 @@ def _upsert_default_kb(self): self._cleanup_idempotency_records() def _ensure_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, - owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None): + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + update_fields: Optional[Set[str]] = None): now = now_ts() with self._db_manager.get_session() as session: Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) @@ -240,13 +286,15 @@ def _ensure_kb(self, kb_id: str, display_name: Optional[str] = None, description updated_at=now, ) else: - if display_name is not None: + if update_fields is None: + update_fields = set() + if 'display_name' in update_fields: row.display_name = display_name - if description is not None: + if 'description' in update_fields: row.description = description - if owner_id is not None: + if 'owner_id' in update_fields: row.owner_id = owner_id - if meta is not None: + if 'meta' in update_fields: row.meta = _to_json(meta) if row.status == KBStatus.DELETED.value: row.status = KBStatus.ACTIVE.value @@ -281,6 +329,21 @@ def _get_kb_algorithm(self, kb_id: str): row = session.query(Rel).filter(Rel.kb_id == kb_id).first() return _orm_to_dict(row) if row else None + @staticmethod + def _build_kb_data(kb_row, algo_row=None): + return { + 'kb_id': kb_row.kb_id, + 'display_name': kb_row.display_name, + 'description': kb_row.description, + 'doc_count': kb_row.doc_count, + 'status': kb_row.status, + 'owner_id': kb_row.owner_id, + 'meta': _from_json(kb_row.meta), + 'algo_id': algo_row.algo_id if algo_row is not None else None, + 'created_at': kb_row.created_at, + 'updated_at': kb_row.updated_at, + } + def _validate_kb_algorithm(self, kb_id: str, algo_id: str): kb = 
self._get_kb(kb_id) if kb is None: @@ -297,6 +360,12 @@ def _validate_kb_algorithm(self, kb_id: str, algo_id: str): ) return binding + def _ensure_algorithm_exists(self, algo_id: str): + algorithms = self.list_algorithms() + if any(item.get('algo_id') == algo_id for item in algorithms): + return + raise DocServiceError('E_INVALID_PARAM', f'invalid algo_id: {algo_id}', {'algo_id': algo_id}) + def _ensure_kb_document(self, kb_id: str, doc_id: str): now = now_ts() created = False @@ -418,6 +487,50 @@ def _record_callback(self, callback_id: str, task_id: str): session.rollback() return False + def _create_task_record(self, task_id: str, task_type: TaskType, doc_id: str, kb_id: str, algo_id: str, + status: DocStatus, message: Optional[Dict[str, Any]] = None): + now = now_ts() + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) + session.add(Task( + task_id=task_id, + task_type=task_type.value, + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=status.value, + message=_to_json(message), + error_code=None, + error_msg=None, + created_at=now, + updated_at=now, + started_at=None, + finished_at=None, + )) + return self._get_task_record(task_id) + + def _get_task_record(self, task_id: str): + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) + row = session.query(Task).filter(Task.task_id == task_id).first() + if row is None: + return None + task = _orm_to_dict(row) + task['message'] = _from_json(task.get('message')) + return task + + def _update_task_record(self, task_id: str, **fields): + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) + row = session.query(Task).filter(Task.task_id == task_id).first() + if row is None: + return None + for key, value in fields.items(): + setattr(row, key, value) + row.updated_at = now_ts() + session.add(row) + return self._get_task_record(task_id) + def _refresh_kb_doc_count(self, kb_id: str): with self._db_manager.get_session() as session: Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) @@ -454,6 +567,7 @@ def _upsert_doc( path: str, metadata: Dict[str, Any], source_type: SourceType, + upload_status: DocStatus = DocStatus.SUCCESS, ): now = now_ts() file_type = os.path.splitext(path)[1].lstrip('.').lower() or None @@ -468,7 +582,7 @@ def _upsert_doc( filename=filename, path=path, meta=_to_json(metadata), - upload_status=DocStatus.SUCCESS.value, + upload_status=upload_status.value, source_type=source_type.value, file_type=file_type, content_hash=content_hash, @@ -480,7 +594,7 @@ def _upsert_doc( row.filename = filename row.path = path row.meta = _to_json(metadata) - row.upload_status = DocStatus.SUCCESS.value + row.upload_status = upload_status.value row.source_type = source_type.value row.file_type = file_type row.content_hash = content_hash @@ -489,6 +603,17 @@ def _upsert_doc( session.add(row) return self._get_doc(doc_id) + def _set_doc_upload_status(self, doc_id: str, status: DocStatus): + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if row is None: + return None + row.upload_status = status.value + row.updated_at = now_ts() + session.add(row) + return self._get_doc(doc_id) + def _get_parse_snapshot(self, doc_id: str, kb_id: str, algo_id: str): with 
self._db_manager.get_session() as session: State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) @@ -596,17 +721,61 @@ def _upsert_parse_snapshot( session.add(row) return self._get_parse_snapshot(doc_id, kb_id, algo_id) - def _create_parser_task(self, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType): - task_id = str(uuid4()) - req = TaskCreateRequest( - task_id=task_id, - task_type=task_type, - doc_id=doc_id, - kb_id=kb_id, - algo_id=algo_id, - callback_url=self._callback_url, - ) - task_resp = self._parser_client.create_task(req) + def _validate_unique_doc_ids(self, doc_ids: List[str], field_name: str = 'doc_id'): + duplicated = set() + seen = set() + for doc_id in doc_ids: + if doc_id in seen: + duplicated.add(doc_id) + seen.add(doc_id) + if duplicated: + duplicated_list = sorted(duplicated) + raise DocServiceError( + 'E_INVALID_PARAM', + f'duplicate {field_name} detected', + {f'duplicate_{field_name}s': duplicated_list}, + ) + + def _call_parser_client(self, method, *args, **kwargs): + try: + return method(*args, **kwargs) + except TypeError as exc: + if 'callback_url' not in kwargs or 'callback_url' not in str(exc): + raise + compat_kwargs = dict(kwargs) + compat_kwargs.pop('callback_url', None) + return method(*args, **compat_kwargs) + + def _create_parser_task(self, task_id: str, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, + file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + reparse_group: Optional[str] = None): + if task_type in (TaskType.DOC_ADD, TaskType.DOC_TRANSFER): + if not file_path: + raise RuntimeError(f'file_path is required for task_type {task_type.value}') + task_resp = self._call_parser_client( + self._parser_client.add_doc, + task_id, algo_id, kb_id, doc_id, file_path, metadata, callback_url=self._callback_url, + ) + elif task_type == TaskType.DOC_REPARSE: + if not file_path: + raise RuntimeError('file_path is required for reparse task') + task_resp = self._call_parser_client( + self._parser_client.add_doc, + task_id, algo_id, kb_id, doc_id, file_path, metadata, + reparse_group=reparse_group or 'all', callback_url=self._callback_url, + ) + elif task_type == TaskType.DOC_UPDATE_META: + task_resp = self._call_parser_client( + self._parser_client.update_meta, + task_id, algo_id, kb_id, doc_id, metadata, file_path, callback_url=self._callback_url, + ) + elif task_type == TaskType.DOC_DELETE: + task_resp = self._call_parser_client( + self._parser_client.delete_doc, + task_id, algo_id, kb_id, doc_id, callback_url=self._callback_url, + ) + else: + raise RuntimeError(f'unsupported task type: {task_type.value}') if task_resp.code != 200: raise RuntimeError(f'failed to enqueue parser task: {task_resp.msg}') return task_id @@ -614,10 +783,22 @@ def _create_parser_task(self, doc_id: str, kb_id: str, algo_id: str, task_type: def _enqueue_task( self, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, idempotency_key: Optional[str] = None, priority: int = 0, + file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + reparse_group: Optional[str] = None, ): - task_id = self._create_parser_task(doc_id, kb_id, algo_id, task_type) + task_id = str(uuid4()) + task_message = { + 'doc_id': doc_id, + 'kb_id': kb_id, + 'algo_id': algo_id, + 'file_path': file_path, + 'metadata': metadata, + 'reparse_group': reparse_group, + } + task_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING + self._create_task_record(task_id, task_type, doc_id, kb_id, algo_id, 
task_status, message=task_message) parse_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING - snapshot = self._upsert_parse_snapshot( + self._upsert_parse_snapshot( doc_id=doc_id, kb_id=kb_id, algo_id=algo_id, @@ -633,37 +814,122 @@ def _enqueue_task( error_msg=None, failed_stage=None, ) - return task_id, snapshot + try: + self._create_parser_task( + task_id, doc_id, kb_id, algo_id, task_type, + file_path=file_path, metadata=metadata, reparse_group=reparse_group, + ) + except Exception as exc: + finished_at = now_ts() + error_msg = str(exc) + self._update_task_record( + task_id, + status=DocStatus.FAILED.value, + error_code='PARSER_SUBMIT_FAILED', + error_msg=error_msg, + finished_at=finished_at, + ) + self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=DocStatus.FAILED, + **self._build_snapshot_update( + self._get_parse_snapshot(doc_id, kb_id, algo_id), + task_type=task_type, + current_task_id=task_id, + error_code='PARSER_SUBMIT_FAILED', + error_msg=error_msg, + failed_stage='SUBMIT', + finished_at=finished_at, + ), + ) + self._apply_doc_upload_status(doc_id, task_type, DocStatus.FAILED) + raise + return task_id, self._get_parse_snapshot(doc_id, kb_id, algo_id) - def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: - self._validate_kb_algorithm(request.kb_id, request.algo_id) - items: List[Dict[str, Any]] = [] + def _apply_doc_upload_status(self, doc_id: str, task_type: TaskType, status: DocStatus): + if task_type == TaskType.DOC_ADD: + self._set_doc_upload_status(doc_id, status) + return + if task_type == TaskType.DOC_DELETE: + if status == DocStatus.DELETING: + if self._doc_relation_count(doc_id) <= 1: + self._set_doc_upload_status(doc_id, DocStatus.DELETING) + return + if status == DocStatus.DELETED: + target = DocStatus.SUCCESS if self._doc_relation_count(doc_id) > 0 else DocStatus.DELETED + self._set_doc_upload_status(doc_id, target) + return + if status in (DocStatus.FAILED, DocStatus.CANCELED): + target = DocStatus.SUCCESS if self._doc_relation_count(doc_id) > 0 else DocStatus.DELETED + self._set_doc_upload_status(doc_id, target) + return + + def _prepare_upload_items(self, request: UploadRequest) -> List[Dict[str, Any]]: + prepared_items: List[Dict[str, Any]] = [] for item in request.items: file_path = item.file_path if not os.path.exists(file_path): raise DocServiceError('E_INVALID_PARAM', f'file not found: {file_path}') - doc_id = gen_doc_id(file_path, doc_id=item.doc_id) - if self._has_kb_document(request.kb_id, doc_id): - self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'upload') + prepared_items.append({ + 'file_path': file_path, + 'metadata': item.metadata, + 'doc_id': gen_doc_id(file_path, doc_id=item.doc_id), + 'filename': os.path.basename(file_path), + }) + + self._validate_unique_doc_ids([item['doc_id'] for item in prepared_items]) + + for item in prepared_items: + if self._has_kb_document(request.kb_id, item['doc_id']): + self._assert_action_allowed(item['doc_id'], request.kb_id, request.algo_id, 'upload') + return prepared_items + + def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: + self._validate_kb_algorithm(request.kb_id, request.algo_id) + prepared_items = self._prepare_upload_items(request) + items: List[Dict[str, Any]] = [] + for item in prepared_items: + doc_id = item['doc_id'] + file_path = item['file_path'] + metadata = item['metadata'] doc = self._upsert_doc( doc_id=doc_id, - filename=os.path.basename(file_path), + 
filename=item['filename'], path=file_path, - metadata=item.metadata, + metadata=metadata, source_type=request.source_type, + upload_status=DocStatus.WAITING, ) self._ensure_kb_document(request.kb_id, doc_id) - task_id, snapshot = self._enqueue_task( - doc_id, request.kb_id, request.algo_id, TaskType.DOC_ADD, - idempotency_key=request.idempotency_key, - ) + try: + task_id, snapshot = self._enqueue_task( + doc_id, request.kb_id, request.algo_id, TaskType.DOC_ADD, + idempotency_key=request.idempotency_key, + file_path=file_path, + metadata=metadata, + ) + error_code = None + error_msg = None + accepted = True + except Exception as exc: + snapshot = self._get_parse_snapshot(doc_id, request.kb_id, request.algo_id) or {} + doc = self._get_doc(doc_id) or doc + task_id = snapshot.get('current_task_id') + error_code = snapshot.get('last_error_code') or type(exc).__name__ + error_msg = snapshot.get('last_error_msg') or str(exc) + accepted = False items.append({ 'doc_id': doc_id, 'kb_id': request.kb_id, 'algo_id': request.algo_id, 'upload_status': doc['upload_status'], - 'parse_status': snapshot['status'], + 'parse_status': snapshot.get('status', DocStatus.FAILED.value), 'task_id': task_id, + 'accepted': accepted, + 'error_code': error_code, + 'error_msg': error_msg, }) return items @@ -678,20 +944,26 @@ def add_files(self, request: AddRequest) -> List[Dict[str, Any]]: def reparse(self, request: ReparseRequest) -> List[str]: self._validate_kb_algorithm(request.kb_id, request.algo_id) + self._validate_unique_doc_ids(request.doc_ids, field_name='doc_id') task_ids = [] for doc_id in request.doc_ids: - if self._get_doc(doc_id) is None or not self._has_kb_document(request.kb_id, doc_id): + doc = self._get_doc(doc_id) + if doc is None or not self._has_kb_document(request.kb_id, doc_id): raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'reparse') task_id, _ = self._enqueue_task( doc_id, request.kb_id, request.algo_id, TaskType.DOC_REPARSE, idempotency_key=request.idempotency_key, + file_path=doc.get('path'), + metadata=_from_json(doc.get('meta')), + reparse_group='all', ) task_ids.append(task_id) return task_ids def delete(self, request: DeleteRequest) -> List[Dict[str, Any]]: self._validate_kb_algorithm(request.kb_id, request.algo_id) + self._validate_unique_doc_ids(request.doc_ids, field_name='doc_id') items: List[Dict[str, Any]] = [] for doc_id in request.doc_ids: doc = self._get_doc(doc_id) @@ -738,6 +1010,8 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: task_id, snapshot = self._enqueue_task( item.doc_id, item.target_kb_id, item.target_algo_id, TaskType.DOC_TRANSFER, idempotency_key=request.idempotency_key, + file_path=doc.get('path'), + metadata=_from_json(doc.get('meta')), ) items.append({ 'doc_id': item.doc_id, @@ -755,6 +1029,7 @@ def patch_metadata(self, request: MetadataPatchRequest): self._validate_kb_algorithm(request.kb_id, request.algo_id) updated = [] failed = [] + self._validate_unique_doc_ids([item.doc_id for item in request.items], field_name='doc_id') for item in request.items: doc = self._get_doc(item.doc_id) if doc is None or not self._has_kb_document(request.kb_id, item.doc_id): @@ -772,6 +1047,8 @@ def patch_metadata(self, request: MetadataPatchRequest): task_id, _ = self._enqueue_task( item.doc_id, request.kb_id, request.algo_id, TaskType.DOC_UPDATE_META, idempotency_key=request.idempotency_key, + file_path=doc.get('path'), + metadata=merged, ) 
updated.append({'doc_id': item.doc_id, 'task_id': task_id}) return { @@ -782,24 +1059,54 @@ def patch_metadata(self, request: MetadataPatchRequest): } def _sync_doc_upload_status(self, doc_id: str): - with self._db_manager.get_session() as session: - Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) - Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) - row = session.query(Doc).filter(Doc.doc_id == doc_id).first() - if row is None: - return - has_rel = session.query(Rel).filter(Rel.doc_id == doc_id).first() is not None - row.upload_status = DocStatus.SUCCESS.value if has_rel else DocStatus.DELETED.value - row.updated_at = now_ts() - session.add(row) + target = DocStatus.SUCCESS if self._doc_relation_count(doc_id) > 0 else DocStatus.DELETED + self._set_doc_upload_status(doc_id, target) + + @staticmethod + def _build_snapshot_update(snapshot: Optional[Dict[str, Any]], **overrides): + snapshot = snapshot or {} + data = { + 'task_type': TaskType(snapshot['task_type']) if snapshot.get('task_type') else None, + 'current_task_id': snapshot.get('current_task_id'), + 'idempotency_key': snapshot.get('idempotency_key'), + 'priority': snapshot.get('priority', 0), + 'task_score': snapshot.get('task_score'), + 'retry_count': snapshot.get('retry_count', 0), + 'max_retry': snapshot.get('max_retry', 3), + 'lease_owner': snapshot.get('lease_owner'), + 'lease_until': snapshot.get('lease_until'), + 'error_code': snapshot.get('last_error_code'), + 'error_msg': snapshot.get('last_error_msg'), + 'failed_stage': snapshot.get('failed_stage'), + 'queued_at': snapshot.get('queued_at'), + 'started_at': snapshot.get('started_at'), + 'finished_at': snapshot.get('finished_at'), + } + data.update(overrides) + return data + + def _resolve_callback_task(self, callback: TaskCallbackRequest): + task = self._get_task_record(callback.task_id) + if task is not None: + return task + payload = callback.payload or {} + required_fields = {'task_type', 'doc_id', 'kb_id', 'algo_id'} + if required_fields.issubset(payload.keys()): + return { + 'task_id': callback.task_id, + 'task_type': payload['task_type'], + 'doc_id': payload['doc_id'], + 'kb_id': payload['kb_id'], + 'algo_id': payload['algo_id'], + } + return None def on_task_callback(self, callback: TaskCallbackRequest): if not self._record_callback(callback.callback_id, callback.task_id): return {'ack': True, 'deduped': True, 'ignored_reason': None} - task = self._parser_client.get_task(callback.task_id) - if task.code != 200: + task_data = self._resolve_callback_task(callback) + if task_data is None: return {'ack': True, 'ignored_reason': 'task_not_found'} - task_data = task.data doc_id = task_data['doc_id'] kb_id = task_data['kb_id'] algo_id = task_data['algo_id'] @@ -809,40 +1116,71 @@ def on_task_callback(self, callback: TaskCallbackRequest): return {'ack': True, 'deduped': False, 'ignored_reason': 'stale_task_callback'} if callback.event_type == CallbackEventType.START: + self._update_task_record( + callback.task_id, + status=DocStatus.WORKING.value, + started_at=now_ts(), + finished_at=None, + error_code=None, + error_msg=None, + ) + start_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WORKING self._upsert_parse_snapshot( doc_id=doc_id, kb_id=kb_id, algo_id=algo_id, - status=DocStatus.WORKING, - task_type=task_type, - current_task_id=callback.task_id, - started_at=now_ts(), - queued_at=None, - finished_at=None, + status=start_status, + **self._build_snapshot_update( + snapshot, + 
task_type=task_type, + current_task_id=callback.task_id, + started_at=now_ts(), + finished_at=None, + error_code=None, + error_msg=None, + failed_stage=None, + ), ) + if task_type == TaskType.DOC_ADD: + self._apply_doc_upload_status(doc_id, task_type, DocStatus.WORKING) + elif task_type == TaskType.DOC_DELETE: + self._apply_doc_upload_status(doc_id, task_type, DocStatus.DELETING) return {'ack': True, 'deduped': False, 'ignored_reason': None} final_status = callback.status + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.SUCCESS: + final_status = DocStatus.DELETED failed_stage = None if final_status == DocStatus.FAILED: failed_stage = 'DELETE' if task_type == TaskType.DOC_DELETE else 'PARSE' + self._update_task_record( + callback.task_id, + status=final_status.value, + error_code=callback.error_code, + error_msg=callback.error_msg, + finished_at=now_ts(), + ) + self._upsert_parse_snapshot( doc_id=doc_id, kb_id=kb_id, algo_id=algo_id, status=final_status, - task_type=task_type, - current_task_id=callback.task_id, - error_code=callback.error_code, - error_msg=callback.error_msg, - failed_stage=failed_stage, - finished_at=now_ts(), + **self._build_snapshot_update( + snapshot, + task_type=task_type, + current_task_id=callback.task_id, + error_code=callback.error_code, + error_msg=callback.error_msg, + failed_stage=failed_stage, + finished_at=now_ts(), + ), ) if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED: self._remove_kb_document(kb_id, doc_id) - self._sync_doc_upload_status(doc_id) + self._apply_doc_upload_status(doc_id, task_type, final_status) return {'ack': True, 'deduped': False, 'ignored_reason': None} @@ -904,9 +1242,7 @@ def get_doc_detail(self, doc_id: str): snapshot = matched_items[0].get('snapshot') if matched_items else None latest_task = None if snapshot and snapshot.get('current_task_id'): - latest_task_resp = self._parser_client.get_task(snapshot['current_task_id']) - if latest_task_resp.code == 200: - latest_task = latest_task_resp.data + latest_task = self._get_task_record(snapshot['current_task_id']) return { 'doc': doc, 'relation': relation, @@ -917,21 +1253,88 @@ def get_doc_detail(self, doc_id: str): } def list_tasks(self, status: Optional[List[str]], page: int, page_size: int): - return self._parser_client.list_tasks(status=status, page=page, page_size=page_size) + parser_list_tasks = getattr(self._parser_client, 'list_tasks', None) + if callable(parser_list_tasks): + try: + return parser_list_tasks(status=status, page=page, page_size=page_size) + except Exception: + pass + page = max(page, 1) + page_size = max(page_size, 1) + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) + query = session.query(Task) + if status: + query = query.filter(Task.status.in_(status)) + total = query.count() + rows = ( + query.order_by(Task.created_at.desc()) + .offset((page - 1) * page_size) + .limit(page_size) + .all() + ) + items = [] + for row in rows: + task = _orm_to_dict(row) + task['message'] = _from_json(task.get('message')) + items.append(task) + return BaseResponse(code=200, msg='success', data={ + 'items': items, + 'total': total, + 'page': page, + 'page_size': page_size, + }) def get_task(self, task_id: str): - return self._parser_client.get_task(task_id) + parser_get_task = getattr(self._parser_client, 'get_task', None) + if callable(parser_get_task): + try: + return parser_get_task(task_id) + except Exception: + pass + task = 
self._get_task_record(task_id) + if task is None: + return BaseResponse(code=404, msg='task not found', data=None) + return BaseResponse(code=200, msg='success', data=task) def get_tasks_batch(self, task_ids: List[str]): items = [] for task_id in task_ids: - resp = self._parser_client.get_task(task_id) + resp = self.get_task(task_id) if resp.code == 200 and resp.data is not None: items.append(resp.data) return {'items': items} def cancel_task(self, task_id: str): - return self._parser_client.cancel_task(task_id) + task = self._get_task_record(task_id) + if task is None: + return BaseResponse(code=404, msg='task not found', data={'task_id': task_id, 'cancel_status': False}) + if task.get('status') != DocStatus.WAITING.value: + return BaseResponse( + code=409, + msg='task cannot be canceled', + data={'task_id': task_id, 'cancel_status': False, 'status': task.get('status')}, + ) + resp = self._parser_client.cancel_task(task_id) + if resp.code != 200: + return resp + resp_data = resp.data or {} + if not resp_data.get('cancel_status'): + return BaseResponse( + code=409, + msg=resp_data.get('message', 'task cannot be canceled'), + data={'task_id': task_id, 'cancel_status': False, 'status': task.get('status')}, + ) + self.on_task_callback(TaskCallbackRequest( + task_id=task_id, + event_type=CallbackEventType.FINISH, + status=DocStatus.CANCELED, + )) + return BaseResponse( + code=200, + msg='success', + data={'task_id': task_id, 'cancel_status': True, 'status': DocStatus.CANCELED.value}, + ) def list_algorithms(self): resp = self._parser_client.list_algorithms() @@ -970,39 +1373,131 @@ def health(self): 'deps': {'sql': True}, } - def list_kbs(self): + def list_kbs( + self, + page: int = 1, + page_size: int = 20, + keyword: Optional[str] = None, + status: Optional[List[str]] = None, + owner_id: Optional[str] = None, + ): + page = max(page, 1) + page_size = max(page_size, 1) with self._db_manager.get_session() as session: Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) - rows = session.query(Kb).order_by(Kb.updated_at.desc()).all() - items = [] - for row in rows: - items.append({ - 'kb_id': row.kb_id, - 'display_name': row.display_name, - 'description': row.description, - 'doc_count': row.doc_count, - 'status': row.status, - 'owner_id': row.owner_id, - 'meta': _from_json(row.meta), - 'created_at': row.created_at, - 'updated_at': row.updated_at, - }) - return {'items': items} + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + query = session.query(Kb, Rel).outerjoin(Rel, Rel.kb_id == Kb.kb_id) + if keyword: + like_expr = f'%{keyword}%' + query = query.filter( + sqlalchemy.or_(Kb.kb_id.like(like_expr), Kb.display_name.like(like_expr), Kb.description.like(like_expr)) + ) + if status: + query = query.filter(Kb.status.in_(status)) + if owner_id: + query = query.filter(Kb.owner_id == owner_id) + total = query.count() + rows = ( + query.order_by(Kb.updated_at.desc(), Kb.created_at.desc()) + .offset((page - 1) * page_size) + .limit(page_size) + .all() + ) + items = [self._build_kb_data(kb_row, algo_row) for kb_row, algo_row in rows] + return {'items': items, 'total': total, 'page': page, 'page_size': page_size} + + def get_kb(self, kb_id: str): + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + row = ( + session.query(Kb, Rel) + .outerjoin(Rel, Rel.kb_id == Kb.kb_id) + .filter(Kb.kb_id == kb_id) + .first() + ) + if 
row is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + kb_row, algo_row = row + return self._build_kb_data(kb_row, algo_row) + + def batch_get_kbs(self, kb_ids: List[str]): + if not kb_ids: + raise DocServiceError('E_INVALID_PARAM', 'kb_ids is required', {'kb_ids': kb_ids}) + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + rows = ( + session.query(Kb, Rel) + .outerjoin(Rel, Rel.kb_id == Kb.kb_id) + .filter(Kb.kb_id.in_(kb_ids)) + .all() + ) + row_map = {kb_row.kb_id: self._build_kb_data(kb_row, algo_row) for kb_row, algo_row in rows} + items = [] + missing_kb_ids = [] + for kb_id in kb_ids: + if kb_id in row_map: + items.append(row_map[kb_id]) + else: + missing_kb_ids.append(kb_id) + return {'items': items, 'missing_kb_ids': missing_kb_ids} def create_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, algo_id: str = '__default__'): if not kb_id: raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + self._ensure_algorithm_exists(algo_id) binding = self._get_kb_algorithm(kb_id) if binding is not None and binding['algo_id'] != algo_id: raise DocServiceError( 'E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm {binding["algo_id"]}', {'kb_id': kb_id, 'bound_algo_id': binding['algo_id'], 'requested_algo_id': algo_id} ) - self._ensure_kb(kb_id, display_name=display_name, description=description, owner_id=owner_id, meta=meta) + update_fields = set() + if display_name is not None: + update_fields.add('display_name') + if description is not None: + update_fields.add('description') + if owner_id is not None: + update_fields.add('owner_id') + if meta is not None: + update_fields.add('meta') + self._ensure_kb( + kb_id, display_name=display_name, description=description, owner_id=owner_id, meta=meta, + update_fields=update_fields, + ) self._ensure_kb_algorithm(kb_id, algo_id) - return {'kb_id': kb_id, 'status': KBStatus.ACTIVE.value} + return self.get_kb(kb_id) + + def update_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + algo_id: Optional[str] = None, explicit_fields: Optional[Set[str]] = None): + if not kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + kb = self._get_kb(kb_id) + if kb is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + explicit_fields = explicit_fields or set() + if 'algo_id' in explicit_fields: + if algo_id is None: + raise DocServiceError('E_INVALID_PARAM', 'algo_id cannot be null', {'kb_id': kb_id}) + self._ensure_algorithm_exists(algo_id) + if algo_id is not None: + binding = self._get_kb_algorithm(kb_id) + if binding is None: + self._ensure_kb_algorithm(kb_id, algo_id) + elif binding['algo_id'] != algo_id: + raise DocServiceError( + 'E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm {binding["algo_id"]}', + {'kb_id': kb_id, 'bound_algo_id': binding['algo_id'], 'requested_algo_id': algo_id} + ) + self._ensure_kb( + kb_id, display_name=display_name, description=description, owner_id=owner_id, meta=meta, + update_fields=explicit_fields & {'display_name', 'description', 'owner_id', 'meta'}, + ) + return self.get_kb(kb_id) def delete_kb(self, kb_id: str): if not kb_id: diff 
--git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py index 696fe3d0e..827a24957 100644 --- a/lazyllm/tools/rag/doc_service/doc_server.py +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -10,12 +10,15 @@ from lazyllm.thirdparty import fastapi from ..utils import BaseResponse, _get_default_db_config, ensure_call_endpoint -from .base import AddRequest, DeleteRequest, DocServiceError, MetadataPatchRequest, ReparseRequest -from .base import SourceType, TaskCallbackRequest +from .base import ( + AddRequest, DeleteRequest, DocServiceError, KbBatchQueryRequest, KbCreateRequest, KbUpdateRequest, + MetadataPatchRequest, ReparseRequest, +) +from .base import CallbackEventType, DocStatus, SourceType, TaskCallbackRequest from .base import TransferRequest from .base import UploadRequest, AddFileItem from .doc_manager import DocManager -from .parsing_server import ParsingTaskServer +from ..parsing_service.base import TaskStatus, TaskType class DocServer(ModuleBase): @@ -41,26 +44,16 @@ def __init__( @once_wrapper(reset_on_pickle=True) def _lazy_init(self): os.makedirs(self._storage_dir, exist_ok=True) - if self._parser_url: - self._manager = DocManager( - db_config=self._db_config, - parser_url=self._parser_url, - callback_url=self._callback_url, - ) - else: - parser_db = self._parser_db_config or _get_default_db_config('doc_service_parser') - self._parser = ParsingTaskServer(db_config=parser_db, poll_interval=self._parser_poll_interval) - self._parser.start() - self._manager = DocManager( - db_config=self._db_config, - parser_server=self._parser, - callback_url=self._callback_url, - ) + if not self._parser_url: + raise ValueError('parser_url is required; doc_service no longer starts a mock parsing server') + self._manager = DocManager( + db_config=self._db_config, + parser_url=self._parser_url, + callback_url=self._callback_url, + ) def stop(self): - self._lazy_init() - if self._parser: - self._parser.stop() + return None def set_runtime_callback_url(self, callback_url: str): self._lazy_init() @@ -115,6 +108,13 @@ def _build_upload_payload(request: UploadRequest, file_identities: Optional[List 'items': items, } + @staticmethod + def _build_update_kb_payload(kb_id: str, request: KbUpdateRequest): + payload = request.model_dump(mode='json', exclude_unset=True) + payload['kb_id'] = kb_id + payload['explicit_fields'] = sorted(request.model_fields_set) + return payload + def _gen_unique_upload_path(self, filename: str, reserved_paths: Optional[set] = None): safe_name = os.path.basename(filename) or 'upload.bin' file_path = os.path.join(self._storage_dir, safe_name) @@ -137,6 +137,65 @@ def _run_upload(self, request: UploadRequest, payload: Optional[Dict[str, Any]] lambda: {'items': self._manager.upload(request)} )) + @staticmethod + def _normalize_task_callback(callback: Any) -> TaskCallbackRequest: + if isinstance(callback, TaskCallbackRequest): + return callback + if not isinstance(callback, dict): + raise DocServiceError('E_INVALID_PARAM', 'invalid callback payload') + + payload = dict(callback.get('payload') or {}) + for field in ('task_type', 'doc_id', 'kb_id', 'algo_id'): + if callback.get(field) is not None and field not in payload: + payload[field] = callback[field] + + event_type = callback.get('event_type') + status = callback.get('status') + task_status = callback.get('task_status') + + try: + if status is not None: + normalized_status = DocStatus(status) + normalized_event_type = CallbackEventType(event_type) if event_type else ( + 
CallbackEventType.START + if normalized_status in (DocStatus.WAITING, DocStatus.WORKING) else CallbackEventType.FINISH + ) + elif task_status is not None: + normalized_status = DocStatus(task_status) + normalized_event_type = CallbackEventType(event_type) if event_type else ( + CallbackEventType.START + if normalized_status in (DocStatus.WAITING, DocStatus.WORKING) + else CallbackEventType.FINISH + ) + else: + raise DocServiceError('E_INVALID_PARAM', 'status or task_status is required') + except ValueError as exc: + raise DocServiceError('E_INVALID_PARAM', str(exc)) from exc + + callback_data = { + 'callback_id': callback.get('callback_id'), + 'task_id': callback.get('task_id'), + 'event_type': normalized_event_type, + 'status': normalized_status, + 'error_code': callback.get('error_code'), + 'error_msg': callback.get('error_msg'), + 'payload': payload, + } + return TaskCallbackRequest.model_validate({k: v for k, v in callback_data.items() if v is not None}) + + @staticmethod + def _format_task_view(task: Optional[Dict[str, Any]]): + if not isinstance(task, dict): + return task + return dict(task) + + def _format_task_response_data(self, data: Any): + if isinstance(data, dict) and isinstance(data.get('items'), list): + payload = dict(data) + payload['items'] = [self._format_task_view(item) for item in data['items']] + return payload + return self._format_task_view(data) + def upload_request(self, request: UploadRequest): self._lazy_init() return self._run_upload(request) @@ -274,13 +333,23 @@ def patch_metadata(self, request: MetadataPatchRequest): def list_tasks(self, status: Optional[List[str]] = None, page: int = 1, page_size: int = 20): self._lazy_init() resp = self._manager.list_tasks(status, page, page_size) - return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + return self._response( + data=self._format_task_response_data(resp.data), + code=resp.code, + msg=resp.msg, + status_code=resp.code, + ) @app.get('/v1/tasks/{task_id}') def get_task(self, task_id: str): self._lazy_init() resp = self._manager.get_task(task_id) - return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + return self._response( + data=self._format_task_response_data(resp.data), + code=resp.code, + msg=resp.msg, + status_code=resp.code, + ) def cancel_task_by_id(self, task_id: str): self._lazy_init() @@ -308,10 +377,15 @@ def _cancel(): '/v1/tasks/cancel', idempotency_key, payload, _cancel )) + def task_callback(self, callback: Any): + self._lazy_init() + return self._run(lambda: self._manager.on_task_callback(self._normalize_task_callback(callback))) + @app.post('/v1/internal/callbacks/tasks') - def task_callback(self, callback: TaskCallbackRequest): + async def task_callback_http(self, request: 'fastapi.Request'): self._lazy_init() - return self._run(lambda: self._manager.on_task_callback(callback)) + payload = await request.json() + return self.task_callback(payload) @app.get('/v1/algo/list') def list_algo(self): @@ -371,17 +445,45 @@ async def get_task_info(self, request: 'fastapi.Request'): return self._response(data={'biz_code': 'E_INVALID_PARAM'}, code=400, msg='task_id is required', status_code=400) resp = self._manager.get_task(task_id) - return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + return self._response( + data=self._format_task_response_data(resp.data), + code=resp.code, + msg=resp.msg, + status_code=resp.code, + ) def get_task_info_impl(self, task_id: str): self._lazy_init() resp = 
self._manager.get_task(task_id) - return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + return self._response( + data=self._format_task_response_data(resp.data), + code=resp.code, + msg=resp.msg, + status_code=resp.code, + ) @app.get('/v1/kbs') - def list_kbs(self): + def list_kbs( + self, + page: int = 1, + page_size: int = 20, + keyword: Optional[str] = None, + status: Optional[List[str]] = None, + owner_id: Optional[str] = None, + ): + self._lazy_init() + return self._run(lambda: self._manager.list_kbs( + page=page, + page_size=page_size, + keyword=keyword, + status=status, + owner_id=owner_id, + )) + + @app.get('/v1/kbs/{kb_id}') + def get_kb(self, kb_id: str): self._lazy_init() - return self._run(lambda: self._manager.list_kbs()) + return self._run(lambda: self._manager.get_kb(kb_id)) def create_kb_by_id(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, @@ -397,21 +499,46 @@ def create_kb_by_id(self, kb_id: str, display_name: Optional[str] = None, descri )) @app.post('/v1/kbs') - async def create_kb(self, request: 'fastapi.Request'): - payload = await request.json() - idempotency_key = payload.get('idempotency_key') + def create_kb(self, request: KbCreateRequest): + self._lazy_init() + payload = request.model_dump(mode='json') return self._run(lambda: self._manager.run_idempotent( - '/v1/kbs', idempotency_key, payload, + '/v1/kbs', request.idempotency_key, payload, lambda: self._manager.create_kb( - payload.get('kb_id'), - display_name=payload.get('display_name'), - description=payload.get('description'), - owner_id=payload.get('owner_id'), - meta=payload.get('meta'), - algo_id=payload.get('algo_id', '__default__'), + request.kb_id, + display_name=request.display_name, + description=request.description, + owner_id=request.owner_id, + meta=request.meta, + algo_id=request.algo_id, + ) + )) + + def update_kb_by_id(self, kb_id: str, request: KbUpdateRequest): + self._lazy_init() + payload = self._build_update_kb_payload(kb_id, request) + return self._run(lambda: self._manager.run_idempotent( + f'/v1/kbs/{kb_id}:patch', request.idempotency_key, payload, + lambda: self._manager.update_kb( + kb_id, + display_name=request.display_name, + description=request.description, + owner_id=request.owner_id, + meta=request.meta, + algo_id=request.algo_id, + explicit_fields=set(request.model_fields_set), ) )) + @app.post('/v1/kbs/{kb_id}/update') + def update_kb(self, kb_id: str, request: KbUpdateRequest): + return self.update_kb_by_id(kb_id, request) + + @app.post('/v1/kbs/batch') + def batch_get_kbs(self, request: KbBatchQueryRequest): + self._lazy_init() + return self._run(lambda: self._manager.batch_get_kbs(request.kb_ids)) + @app.delete('/v1/kbs/{kb_id}') def delete_kb(self, kb_id: str, idempotency_key: Optional[str] = None): self._lazy_init() @@ -465,6 +592,8 @@ def __init__( if url: self._impl = UrlModule(url=ensure_call_endpoint(url)) else: + if not parser_url: + raise ValueError('parser_url is required; doc_service no longer embeds a mock parsing server') self._raw_impl = DocServer._Impl( storage_dir=self._storage_dir, db_config=self._db_config, @@ -553,8 +682,11 @@ def get_task(self, task_id: str): def cancel_task(self, task_id: str): return self._dispatch('cancel_task_by_id', task_id) - def list_kbs(self): - return self._dispatch('list_kbs') + def list_kbs(self, **kwargs): + return self._dispatch('list_kbs', **kwargs) + + def get_kb(self, kb_id: 
str): + return self._dispatch('get_kb', kb_id) def list_chunks(self, page: int = 1, page_size: int = 20): return self._dispatch('list_chunks', page, page_size) @@ -570,6 +702,12 @@ def create_kb(self, kb_id: str, display_name: Optional[str] = None, description: algo_id: str = '__default__'): return self._dispatch('create_kb_by_id', kb_id, display_name, description, owner_id, meta, algo_id) + def update_kb(self, kb_id: str, request: KbUpdateRequest): + return self._dispatch('update_kb_by_id', kb_id, request) + + def batch_get_kbs(self, kb_ids: List[str]): + return self._dispatch('batch_get_kbs', KbBatchQueryRequest(kb_ids=kb_ids)) + def delete_kb(self, kb_id: str): return self._dispatch('delete_kb', kb_id) diff --git a/lazyllm/tools/rag/doc_service/parsing_server.py b/lazyllm/tools/rag/doc_service/parsing_server.py deleted file mode 100644 index d5cca0766..000000000 --- a/lazyllm/tools/rag/doc_service/parsing_server.py +++ /dev/null @@ -1,434 +0,0 @@ -from __future__ import annotations -''' -Mock parsing execution service used in phase-1 refactor validation. - -Note: -This is intentionally isolated from `lazyllm.tools.rag.parsing_service` so that -DocService API contract and state machine can be validated without requiring the -full parser runtime and algorithm registry. -''' - -import json -import threading -import time -import traceback -from typing import Any, Callable, Dict, List, Optional -from datetime import datetime - -import cloudpickle -import requests - -from lazyllm import LOG, FastapiApp as app, ModuleBase, ServerModule, UrlModule, once_wrapper -from lazyllm.thirdparty import fastapi - -from ..utils import BaseResponse, _get_default_db_config, _orm_to_dict, ensure_call_endpoint -from ...sql import SqlManager -from ..parsing_service.base import ALGORITHM_TABLE_INFO -from .base import ( - CallbackEventType, - TaskCallbackRequest, - TaskCreateRequest, - DocStatus, - now_ts, -) - -PARSER_TASK_TABLE_INFO = { - 'name': 'lazyllm_parse_tasks', - 'comment': 'Parse task table for mock parser service', - 'columns': [ - {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, - 'comment': 'Auto increment ID'}, - {'name': 'task_id', 'data_type': 'string', 'nullable': False, 'comment': 'Task ID'}, - {'name': 'task_type', 'data_type': 'string', 'nullable': False, 'comment': 'Task type'}, - {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'comment': 'Document ID'}, - {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, - {'name': 'algo_id', 'data_type': 'string', 'nullable': False, 'comment': 'Algorithm ID'}, - {'name': 'status', 'data_type': 'string', 'nullable': False, 'comment': 'Task status'}, - {'name': 'priority', 'data_type': 'integer', 'nullable': False, 'default': 0, 'comment': 'Task priority'}, - {'name': 'message', 'data_type': 'text', 'nullable': False, 'comment': 'Task payload'}, - {'name': 'callback_url', 'data_type': 'string', 'nullable': True, 'comment': 'Callback URL'}, - {'name': 'error_code', 'data_type': 'string', 'nullable': True, 'comment': 'Error code'}, - {'name': 'error_msg', 'data_type': 'text', 'nullable': True, 'comment': 'Error message'}, - {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, - 'comment': 'Created time'}, - {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, - 'comment': 'Updated time'}, - {'name': 'started_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Started time'}, - {'name': 
'finished_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Finished time'}, - ], -} - - -class ParsingTaskServer(ModuleBase): - class _Impl: - def __init__( - self, - db_config: Optional[Dict[str, Any]] = None, - poll_interval: float = 0.05, - callback_func: Optional[Callable[[TaskCallbackRequest], None]] = None, - ): - self._db_config = db_config or _get_default_db_config('doc_service_parser') - self._poll_interval = poll_interval - self._db_manager = None - self._task_thread = None - self._shutdown = False - self._callback_func = callback_func - - @once_wrapper(reset_on_pickle=True) - def _lazy_init(self): - self._db_manager = SqlManager( - **self._db_config, - tables_info_dict={'tables': [PARSER_TASK_TABLE_INFO, ALGORITHM_TABLE_INFO]}, - ) - self._ensure_indexes() - self._upsert_default_algorithm() - self._shutdown = False - self._task_thread = threading.Thread(target=self._task_worker, daemon=True) - self._task_thread.start() - - def stop(self): - self._shutdown = True - if self._task_thread and self._task_thread.is_alive(): - self._task_thread.join(timeout=2) - - def _ensure_indexes(self): - stmts = [ - 'CREATE UNIQUE INDEX IF NOT EXISTS uq_parse_tasks_task_id ON lazyllm_parse_tasks(task_id)', - 'CREATE INDEX IF NOT EXISTS idx_parse_tasks_status ON lazyllm_parse_tasks(status, updated_at)', - 'CREATE INDEX IF NOT EXISTS idx_parse_tasks_doc ON lazyllm_parse_tasks(doc_id, kb_id, algo_id)', - ] - for stmt in stmts: - self._db_manager.execute_commit(stmt) - - def _upsert_default_algorithm(self): - default_group_info = [ - {'name': 'CoarseChunk', 'type': 'chunk', 'display_name': 'Coarse Chunk'}, - {'name': 'line', 'type': 'chunk', 'display_name': 'Line Chunk'}, - ] - default_info = { - 'store': None, - 'reader': None, - 'node_groups': { - item['name']: {'group_type': item['type'], 'display_name': item['display_name']} - for item in default_group_info - }, - 'schema_extractor': None, - } - with self._db_manager.get_session() as session: - Algo = self._db_manager.get_table_orm_class(ALGORITHM_TABLE_INFO['name']) - row = session.query(Algo).filter(Algo.id == '__default__').first() - if row is None: - session.add( - Algo( - id='__default__', - display_name='Default', - description='Default mock parsing algorithm', - info_pickle=cloudpickle.dumps(default_info), - created_at=now_ts(), - updated_at=now_ts(), - ) - ) - - def register_callback(self, callback_func: Callable[[TaskCallbackRequest], None]): - self._callback_func = callback_func - - def _load_task(self, task_id: str): - with self._db_manager.get_session() as session: - Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) - row = session.query(Task).filter(Task.task_id == task_id).first() - return _orm_to_dict(row) if row else None - - def _emit_callback(self, callback_payload: TaskCallbackRequest, callback_url: Optional[str]): - if self._callback_func: - self._callback_func(callback_payload) - return - if callback_url: - response = requests.post(callback_url, json=callback_payload.model_dump(), timeout=5) - if response.status_code >= 400: - raise RuntimeError(f'callback failed: {response.status_code} {response.text}') - - def _update_task(self, task_id: str, **fields): - with self._db_manager.get_session() as session: - Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) - row = session.query(Task).filter(Task.task_id == task_id).first() - if row is None: - return None - for key, value in fields.items(): - setattr(row, key, value) - row.updated_at = now_ts() - session.add(row) - 
return _orm_to_dict(row) - - def _task_worker(self): - while not self._shutdown: - try: - waiting_task = None - with self._db_manager.get_session() as session: - Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) - waiting_task = ( - session.query(Task) - .filter(Task.status == DocStatus.WAITING.value) - .order_by(Task.priority.desc(), Task.created_at.asc()) - .first() - ) - if waiting_task is not None: - waiting_task.status = DocStatus.WORKING.value - waiting_task.started_at = now_ts() - waiting_task.updated_at = now_ts() - session.add(waiting_task) - waiting_task = _orm_to_dict(waiting_task) - if waiting_task is None: - time.sleep(self._poll_interval) - continue - - start_callback = TaskCallbackRequest( - task_id=waiting_task['task_id'], - event_type=CallbackEventType.START, - status=DocStatus.WORKING, - payload={ - 'task_type': waiting_task['task_type'], - 'doc_id': waiting_task['doc_id'], - 'kb_id': waiting_task['kb_id'], - 'algo_id': waiting_task['algo_id'], - }, - ) - self._emit_callback(start_callback, waiting_task.get('callback_url')) - - # Mock workload - time.sleep(self._poll_interval) - final_status = ( - DocStatus.DELETED.value - if waiting_task['task_type'] == 'DOC_DELETE' - else DocStatus.SUCCESS.value - ) - done = self._update_task( - waiting_task['task_id'], - status=final_status, - finished_at=now_ts(), - error_code=None, - error_msg=None, - ) - finish_callback = TaskCallbackRequest( - task_id=waiting_task['task_id'], - event_type=CallbackEventType.FINISH, - status=DocStatus(final_status), - error_code=None, - error_msg=None, - payload={ - 'task_type': waiting_task['task_type'], - 'doc_id': waiting_task['doc_id'], - 'kb_id': waiting_task['kb_id'], - 'algo_id': waiting_task['algo_id'], - 'result': done, - }, - ) - self._emit_callback(finish_callback, waiting_task.get('callback_url')) - except Exception as exc: - LOG.error(f'[ParsingTaskServer] worker loop error: {exc}, {traceback.format_exc()}') - time.sleep(self._poll_interval) - - @app.post('/v1/internal/tasks/create') - def create_task(self, request: TaskCreateRequest): - self._lazy_init() - now = now_ts() - payload = request.model_dump(mode='json') - with self._db_manager.get_session() as session: - Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) - exists = session.query(Task).filter(Task.task_id == request.task_id).first() - if exists is not None: - return BaseResponse(code=200, msg='success', data={'task': _orm_to_dict(exists), 'deduped': True}) - session.add( - Task( - task_id=request.task_id, - task_type=request.task_type.value, - doc_id=request.doc_id, - kb_id=request.kb_id, - algo_id=request.algo_id, - status=DocStatus.WAITING.value, - priority=request.priority, - message=json.dumps(payload, ensure_ascii=False), - callback_url=request.callback_url, - error_code=None, - error_msg=None, - created_at=now, - updated_at=now, - started_at=None, - finished_at=None, - ) - ) - task = self._load_task(request.task_id) - return BaseResponse(code=200, msg='success', data={'task': task, 'deduped': False}) - - @app.post('/v1/internal/tasks/cancel') - def cancel_task(self, request: Dict[str, str]): - self._lazy_init() - task_id = request.get('task_id') - if not task_id: - raise fastapi.HTTPException(status_code=400, detail='task_id is required') - task = self._load_task(task_id) - if not task: - return BaseResponse(code=404, msg='task not found', data={'task_id': task_id, 'cancel_status': False}) - if task['status'] != DocStatus.WAITING.value: - return BaseResponse( - code=409, 
- msg='task cannot be canceled', - data={'task_id': task_id, 'cancel_status': False, 'status': task['status']}, - ) - task = self._update_task(task_id, status=DocStatus.CANCELED.value, finished_at=now_ts()) - callback = TaskCallbackRequest( - task_id=task_id, - event_type=CallbackEventType.FINISH, - status=DocStatus.CANCELED, - payload={ - 'task_type': task.get('task_type'), - 'doc_id': task.get('doc_id'), - 'kb_id': task.get('kb_id'), - 'algo_id': task.get('algo_id'), - }, - ) - try: - self._emit_callback(callback, task.get('callback_url')) - except Exception as exc: - LOG.warning(f'[ParsingTaskServer] cancel callback failed: {exc}') - return BaseResponse( - code=200, - msg='success', - data={'task_id': task_id, 'cancel_status': True, 'status': DocStatus.CANCELED.value}, - ) - - @app.get('/v1/tasks') - def list_tasks(self, status: Optional[List[str]] = None, page: int = 1, page_size: int = 20): - self._lazy_init() - with self._db_manager.get_session() as session: - Task = self._db_manager.get_table_orm_class(PARSER_TASK_TABLE_INFO['name']) - query = session.query(Task) - if status: - query = query.filter(Task.status.in_(status)) - total = query.count() - rows = ( - query.order_by(Task.created_at.desc()) - .offset(max(page - 1, 0) * page_size) - .limit(page_size) - .all() - ) - items = [_orm_to_dict(row) for row in rows] - return BaseResponse(code=200, msg='success', data={ - 'items': items, - 'total': total, - 'page': page, - 'page_size': page_size, - }) - - @app.get('/v1/tasks/{task_id}') - def get_task(self, task_id: str): - self._lazy_init() - task = self._load_task(task_id) - if task is None: - raise fastapi.HTTPException(status_code=404, detail='task not found') - return BaseResponse(code=200, msg='success', data=task) - - @app.get('/v1/algo/list') - def list_algorithms(self): - self._lazy_init() - with self._db_manager.get_session() as session: - Algo = self._db_manager.get_table_orm_class(ALGORITHM_TABLE_INFO['name']) - rows = session.query(Algo).order_by(Algo.created_at.asc()).all() - data = [ - {'algo_id': row.id, 'display_name': row.display_name, 'description': row.description} - for row in rows - ] - return BaseResponse(code=200, msg='success', data=data) - - @app.get('/v1/algo/{algo_id}/groups') - def get_algorithm_groups(self, algo_id: str): - self._lazy_init() - with self._db_manager.get_session() as session: - Algo = self._db_manager.get_table_orm_class(ALGORITHM_TABLE_INFO['name']) - row = session.query(Algo).filter(Algo.id == algo_id).first() - if row is None: - raise fastapi.HTTPException(status_code=404, detail='algo not found') - info = cloudpickle.loads(row.info_pickle) - node_groups = info.get('node_groups', {}) if isinstance(info, dict) else {} - data = [] - for name, group in node_groups.items(): - data.append({ - 'name': name, - 'type': group.get('group_type'), - 'display_name': group.get('display_name'), - }) - return BaseResponse(code=200, msg='success', data=data) - - @app.get('/v1/health') - def health(self): - self._lazy_init() - healthy = self._task_thread is not None and self._task_thread.is_alive() - return BaseResponse(code=200 if healthy else 503, msg='success' if healthy else 'unhealthy', data={ - 'status': 'ok' if healthy else 'degraded', - 'version': 'v1-mock', - 'deps': { - 'sql': bool(self._db_manager), - 'worker': healthy, - }, - }) - - def __init__( - self, - port: Optional[int] = None, - url: Optional[str] = None, - db_config: Optional[Dict[str, Any]] = None, - poll_interval: float = 0.05, - callback_func: Optional[Callable[[TaskCallbackRequest], 
None]] = None, - launcher=None, - ): - super().__init__() - self._raw_impl = None - self._db_config = db_config or _get_default_db_config('doc_service_parser') - if url: - self._impl = UrlModule(url=ensure_call_endpoint(url)) - else: - self._raw_impl = ParsingTaskServer._Impl( - db_config=self._db_config, - poll_interval=poll_interval, - callback_func=callback_func, - ) - self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) - - def start(self): - result = super().start() - if self._raw_impl: - self._dispatch('_lazy_init') - return result - - def stop(self): - if self._raw_impl: - self._dispatch('stop') - if isinstance(self._impl, ServerModule): - self._impl.stop() - - def _dispatch(self, method: str, *args, **kwargs): - impl = self._impl - if isinstance(impl, ServerModule): - return impl._call(method, *args, **kwargs) - return getattr(impl, method)(*args, **kwargs) - - def register_callback(self, callback_func: Callable[[TaskCallbackRequest], None]): - if self._raw_impl: - self._raw_impl.register_callback(callback_func) - - def create_task(self, request: TaskCreateRequest): - return self._dispatch('create_task', request) - - def cancel_task(self, task_id: str): - return self._dispatch('cancel_task', {'task_id': task_id}) - - def list_tasks(self, status: Optional[List[str]] = None, page: int = 1, page_size: int = 20): - return self._dispatch('list_tasks', status, page, page_size) - - def get_task(self, task_id: str): - return self._dispatch('get_task', task_id) - - def list_algorithms(self): - return self._dispatch('list_algorithms') - - def get_algorithm_groups(self, algo_id: str): - return self._dispatch('get_algorithm_groups', algo_id) diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index 0c67a7c65..f7ff8390d 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -29,6 +29,7 @@ class AddDocRequest(BaseModel): kb_id: Optional[str] = None file_infos: List[FileInfo] priority: Optional[int] = 0 + callback_url: Optional[str] = None # NOTE: (db_info, feedback_url) is deprecated, will be removed in the future db_info: Optional[DBInfo] = None feedback_url: Optional[str] = None @@ -40,8 +41,10 @@ class UpdateMetaRequest(BaseModel): kb_id: Optional[str] = None file_infos: List[FileInfo] priority: Optional[int] = 0 + callback_url: Optional[str] = None # NOTE: (db_info) is deprecated, will be removed in the future db_info: Optional[DBInfo] = None + feedback_url: Optional[str] = None class DeleteDocRequest(BaseModel): @@ -50,8 +53,10 @@ class DeleteDocRequest(BaseModel): kb_id: Optional[str] = None doc_ids: List[str] priority: Optional[int] = 0 + callback_url: Optional[str] = None # NOTE: (db_info) is deprecated, will be removed in the future db_info: Optional[DBInfo] = None + feedback_url: Optional[str] = None class CancelTaskRequest(BaseModel): @@ -61,9 +66,8 @@ class CancelTaskRequest(BaseModel): class TaskStatus(str, Enum): WAITING = 'WAITING' WORKING = 'WORKING' - CANCEL_REQUESTED = 'CANCEL_REQUESTED' CANCELED = 'CANCELED' - FINISHED = 'FINISHED' + SUCCESS = 'SUCCESS' FAILED = 'FAILED' @@ -127,9 +131,13 @@ def _calculate_task_score(task_type: str, user_priority: int) -> int: {'name': 'task_type', 'data_type': 'string', 'nullable': False, 'comment': 'Task type: DOC_ADD, DOC_DELETE, DOC_UPDATE_META, DOC_REPARSE'}, {'name': 'task_status', 'data_type': 'string', 'nullable': False, - 'comment': 'Task status: WAITING, WORKING, CANCEL_REQUESTED, CANCELED, FINISHED, FAILED'}, + 
'comment': 'Task status: WAITING, WORKING, CANCELED, SUCCESS, FAILED'}, {'name': 'finished_at', 'data_type': 'datetime', 'nullable': False, 'comment': 'Finish time (set when processing completes)', 'default': datetime.now}, + {'name': 'callback_url', 'data_type': 'string', 'nullable': True, + 'comment': 'Callback target url for built-in HTTP callback'}, + {'name': 'task_context_json', 'data_type': 'string', 'nullable': True, + 'comment': 'Serialized callback context used to build callback payload'}, {'name': 'error_code', 'data_type': 'string', 'nullable': True, 'default': '200', 'comment': 'Error code (varchar64)'}, {'name': 'error_msg', 'data_type': 'string', 'nullable': True, 'default': 'success', diff --git a/lazyllm/tools/rag/parsing_service/queue.py b/lazyllm/tools/rag/parsing_service/queue.py index 01c9f6a49..83d241911 100644 --- a/lazyllm/tools/rag/parsing_service/queue.py +++ b/lazyllm/tools/rag/parsing_service/queue.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional, Any from lazyllm import LOG +import sqlalchemy from ...sql import SqlManager from ..utils import _orm_to_dict @@ -29,11 +30,42 @@ def __init__(self, table_name: str, columns: List[Dict[str, Any]], db_config: Di ] } ) + self._ensure_columns_exist() LOG.info(f'[SQLBasedQueue] Queue {self._table_name} initialized successfully') except Exception as e: LOG.error(f'[SQLBasedQueue] Failed to initialize queue {self._table_name}: {e}') raise + def _ensure_columns_exist(self): + inspector = sqlalchemy.inspect(self._sql_manager.engine) + existing_columns = {column['name'] for column in inspector.get_columns(self._table_name)} + missing_columns = [column for column in self._columns if column['name'] not in existing_columns] + if not missing_columns: + return + + for column in missing_columns: + sql_type = self._sql_manager._sql_type_for(column['data_type']) + if isinstance(sql_type, type): + sql_type = sql_type() + type_sql = sql_type.compile(dialect=self._sql_manager.engine.dialect) + nullable_sql = '' if column.get('nullable', True) else ' NOT NULL' + default_sql = '' + default_value = column.get('default') + if default_value is not None and not callable(default_value): + if isinstance(default_value, str): + escaped_default = default_value.replace("'", "''") + default_sql = f" DEFAULT '{escaped_default}'" + elif isinstance(default_value, bool): + default_sql = f' DEFAULT {int(default_value)}' + else: + default_sql = f' DEFAULT {default_value}' + statement = sqlalchemy.text( + f'ALTER TABLE "{self._table_name}" ADD COLUMN "{column["name"]}" {type_sql}{nullable_sql}{default_sql}' + ) + with self._sql_manager.engine.begin() as connection: + connection.execute(statement) + LOG.info(f'[SQLBasedQueue] Added missing column {column["name"]} to {self._table_name}') + def _build_query(self, session, filter_by: Dict[str, Any] = None): TableCls = self._sql_manager.get_table_orm_class(self._table_name) query = session.query(TableCls) @@ -142,3 +174,21 @@ def clear(self, filter_by: Dict[str, Any] = None) -> int: except Exception as e: LOG.error(f'[SQLBasedQueue] Failed to clear {self._table_name}: {e}') raise + + def update(self, filter_by: Dict[str, Any], **kwargs) -> Optional[Dict[str, Any]]: + try: + with self._sql_manager.get_session() as session: + query = self._build_query(session, filter_by) + record = query.with_for_update().first() + if not record: + return None + + for key, value in kwargs.items(): + setattr(record, key, value) + session.flush() + result = _orm_to_dict(record) + LOG.debug(f'[SQLBasedQueue] Updated record 
in {self._table_name}') + return result + except Exception as e: + LOG.error(f'[SQLBasedQueue] Failed to update {self._table_name}: {e}') + raise diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index 6c94156e6..4b86f48e3 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -1,20 +1,25 @@ import json +import random import threading import time import traceback import cloudpickle -from datetime import datetime -from typing import Any, Callable, Dict, Optional +from datetime import datetime, timedelta +from email.utils import parsedate_to_datetime +from uuid import NAMESPACE_URL, uuid5 +from typing import Any, Callable, Dict, Optional, Tuple from lazyllm import ( LOG, ModuleBase, ServerModule, UrlModule, FastapiApp as app, LazyLLMLaunchersBase as Launcher, once_wrapper ) +import requests from lazyllm.thirdparty import fastapi from .base import ( ALGORITHM_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, FINISHED_TASK_QUEUE_TABLE_INFO, - TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, _calculate_task_score + TaskStatus, TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, + _calculate_task_score ) from .worker import DocumentProcessorWorker as Worker from .queue import _SQLBasedQueue as Queue @@ -26,19 +31,26 @@ from ..doc_to_db import SchemaExtractor from ...sql import SqlManager +CALLBACK_RETRY_MIN_INTERVAL = 5.0 +CALLBACK_RETRY_MAX_INTERVAL = 300.0 +CALLBACK_RETRY_MAX_ATTEMPTS = 5 + class DocumentProcessor(ModuleBase): class _Impl(): def __init__(self, db_config: Optional[Dict[str, Any]] = None, num_workers: int = 1, - post_func: Optional[Callable] = None, path_prefix: Optional[str] = None): + post_func: Optional[Callable] = None, path_prefix: Optional[str] = None, + callback_url: Optional[str] = None): self._db_config = db_config self._num_workers = num_workers self._post_func = post_func + self._callback_url = self._normalize_callback_url(callback_url) if not self._check_post_func(): raise ValueError('Invalid post function!') self._shutdown = False self._path_prefix = path_prefix + self._callback_retry_attempts: Dict[int, int] = {} self._db_manager = None self._waiting_task_queue = None @@ -62,6 +74,8 @@ def _lazy_init(self): table_name=FINISHED_TASK_QUEUE_TABLE_INFO['name'], columns=FINISHED_TASK_QUEUE_TABLE_INFO['columns'], db_config=self._db_config, + order_by='finished_at', + order_desc=False, ) self._post_func_thread = threading.Thread(target=self.process_finished_task, daemon=True) @@ -88,14 +102,19 @@ def process_finished_task(self): '''process finished task in background thread''' while True: try: - finished_task = self._finished_task_queue.dequeue() + finished_task = self._finished_task_queue.peek() if finished_task: - self._callback( - task_id=finished_task.get('task_id'), - task_status=finished_task.get('task_status'), - error_code=finished_task.get('error_code'), - error_msg=finished_task.get('error_msg') - ) + if not self._is_callback_due(finished_task): + time.sleep(0.5) + continue + try: + self._callback(finished_task) + except Exception as exc: + self._schedule_callback_retry(finished_task, exc) + time.sleep(0.1) + continue + self._finished_task_queue.clear(filter_by={'id': finished_task['id']}) + self._callback_retry_attempts.pop(finished_task['id'], None) time.sleep(0.1) else: time.sleep(1) @@ -103,6 +122,159 @@ def process_finished_task(self): LOG.error(f'[DocumentProcessor] Failed to process finished 
task: {e}, {traceback.format_exc()}') time.sleep(10) + @staticmethod + def _normalize_callback_url(callback_url: Optional[str]) -> Optional[str]: + if not callback_url: + return None + return callback_url.rstrip('/') + + def set_callback_url(self, callback_url: Optional[str]): + self._callback_url = self._normalize_callback_url(callback_url) + + @staticmethod + def _normalize_queue_datetime(value: Any) -> Optional[datetime]: + if value is None or isinstance(value, datetime): + return value + if isinstance(value, str): + try: + return datetime.fromisoformat(value) + except ValueError: + return None + return None + + def _is_callback_due(self, finished_task: Dict[str, Any]) -> bool: + finished_at = self._normalize_queue_datetime(finished_task.get('finished_at')) + return finished_at is None or finished_at <= datetime.now() + + @staticmethod + def _load_task_context(finished_task: Dict[str, Any]) -> Dict[str, Any]: + task_context_json = finished_task.get('task_context_json') + if not task_context_json: + raise ValueError('task_context_json is missing in finished task queue') + try: + task_context = json.loads(task_context_json) + except json.JSONDecodeError as exc: + raise ValueError(f'invalid task_context_json: {exc}') from exc + if not isinstance(task_context, dict): + raise ValueError('task_context_json must decode to dict') + return task_context + + def _drop_callback_task(self, finished_task: Dict[str, Any], exc: Exception, attempt: int, reason: str): + LOG.error('[DocumentProcessor] Callback delivery dropped queue item.' + f' queue_id={finished_task.get("id")}' + f' task_id={finished_task.get("task_id")}' + f' task_status={finished_task.get("task_status")}' + f' reason={reason}' + f' attempts={attempt}' + f' callback_url={finished_task.get("callback_url")}' + f' task_context_json={finished_task.get("task_context_json")}' + f' error={type(exc).__name__}: {exc}') + self._finished_task_queue.clear(filter_by={'id': finished_task['id']}) + self._callback_retry_attempts.pop(finished_task['id'], None) + + @staticmethod + def _parse_retry_after_seconds(value: Optional[str]) -> Optional[float]: + if not value: + return None + try: + return max(float(value), 0.0) + except (TypeError, ValueError): + pass + try: + retry_at = parsedate_to_datetime(value) + except (TypeError, ValueError, IndexError, OverflowError): + return None + if retry_at.tzinfo is not None: + delay = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds() + else: + delay = (retry_at - datetime.now()).total_seconds() + return max(delay, 0.0) + + def _should_retry_callback_error(self, exc: Exception) -> Tuple[bool, Optional[float], str]: + if isinstance(exc, ValueError): + return False, None, 'invalid_callback_payload' + if isinstance(exc, (requests.Timeout, requests.ConnectionError)): + return True, None, 'transient_network_error' + if isinstance(exc, requests.HTTPError): + response = exc.response + if response is None: + return True, None, 'http_error_without_response' + status_code = response.status_code + if status_code in (408, 425): + return True, None, f'http_{status_code}' + if status_code == 429: + return True, self._parse_retry_after_seconds(response.headers.get('Retry-After')), 'http_429' + if status_code >= 500: + return True, None, f'http_{status_code}' + if 400 <= status_code < 500: + return False, None, f'http_{status_code}' + return True, None, type(exc).__name__ + + def _schedule_callback_retry(self, finished_task: Dict[str, Any], exc: Exception) -> bool: + queue_id = finished_task['id'] + attempt = 
self._callback_retry_attempts.get(queue_id, 0) + 1 + self._callback_retry_attempts[queue_id] = attempt + should_retry, retry_after_seconds, reason = self._should_retry_callback_error(exc) + if not should_retry: + self._drop_callback_task(finished_task, exc, attempt, reason) + return False + if attempt >= CALLBACK_RETRY_MAX_ATTEMPTS: + self._drop_callback_task(finished_task, exc, attempt, 'retry_exhausted') + return False + delay = CALLBACK_RETRY_MIN_INTERVAL * (2 ** (attempt - 1)) + delay *= random.uniform(0.8, 1.2) + if retry_after_seconds is not None: + delay = max(delay, retry_after_seconds) + delay = min(delay, CALLBACK_RETRY_MAX_INTERVAL) + retry_at = datetime.now() + timedelta(seconds=delay) + self._finished_task_queue.update(filter_by={'id': queue_id}, finished_at=retry_at) + LOG.warning(f'[DocumentProcessor] Callback delivery failed for queue_id={queue_id},' + f' task_id={finished_task.get("task_id")}, retry in {delay:.1f}s' + f' (attempt={attempt}/{CALLBACK_RETRY_MAX_ATTEMPTS})' + f' reason={reason}' + f' error={type(exc).__name__}: {exc}') + return True + + def _resolve_callback_url(self, payload: Dict[str, Any]) -> Optional[str]: + return self._normalize_callback_url( + payload.get('callback_url') or payload.get('feedback_url') or self._callback_url + ) + + def _default_post_func(self, finished_task: Dict[str, Any]): + task_id = finished_task.get('task_id') + task_status = finished_task.get('task_status') + error_code = finished_task.get('error_code') + error_msg = finished_task.get('error_msg') + callback_url = self._normalize_callback_url(finished_task.get('callback_url')) + if not callback_url: + raise ValueError(f'callback_url is missing for task {task_id}') + task_context = self._load_task_context(finished_task) + + base_payload = {'task_type': task_context.get('task_type'), + 'kb_id': task_context.get('kb_id'), + 'algo_id': task_context.get('algo_id')} + items = task_context.get('items') or [{}] + for index, item in enumerate(items): + callback_payload = { + 'callback_id': str(uuid5(NAMESPACE_URL, f'{task_id}:{task_status}:{index}')), + 'task_id': task_id, + 'task_status': task_status, + 'payload': {k: v for k, v in {**base_payload, **item}.items() if v is not None}, + } + for field in ('task_type', 'kb_id', 'algo_id'): + if base_payload.get(field) is not None: + callback_payload[field] = base_payload[field] + if item.get('doc_id') is not None: + callback_payload['doc_id'] = item['doc_id'] + if error_code is not None: + callback_payload['error_code'] = error_code + if error_msg is not None: + callback_payload['error_msg'] = error_msg + + response = requests.post(callback_url, json=callback_payload, timeout=8) + response.raise_for_status() + return True + def register_algorithm(self, name: str, store: _DocumentStore, reader: DirectoryReader, node_groups: Dict[str, Dict], schema_extractor: Optional[SchemaExtractor] = None, display_name: Optional[str] = None, description: Optional[str] = None): @@ -291,6 +463,10 @@ def add_doc(self, request: AddDocRequest): task_type = TaskType.DOC_REPARSE.value else: raise fastapi.HTTPException(status_code=400, detail='no input files or reparse group specified') + payload = request.model_dump() + resolved_callback_url = self._resolve_callback_url(payload) + if resolved_callback_url: + payload['callback_url'] = resolved_callback_url payload_json = json.dumps(payload, ensure_ascii=False) try: @@ -332,6 +508,9 @@ def update_meta(self, request: UpdateMetaRequest): algorithm = self._get_algo(algo_id) if algorithm is None: raise 
fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') + resolved_callback_url = self._resolve_callback_url(payload) + if resolved_callback_url: + payload['callback_url'] = resolved_callback_url payload_json = json.dumps(payload, ensure_ascii=False) try: task_type = TaskType.DOC_UPDATE_META.value @@ -374,6 +553,9 @@ def delete_doc(self, request: DeleteDocRequest): if algorithm is None: raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') + resolved_callback_url = self._resolve_callback_url(payload) + if resolved_callback_url: + payload['callback_url'] = resolved_callback_url payload_json = json.dumps(payload, ensure_ascii=False) try: task_type = TaskType.DOC_DELETE.value @@ -434,8 +616,12 @@ def cancel(self, request: CancelTaskRequest): def _check_post_func(self) -> bool: '''assert post function is callable and params include task_id, task_status, error_code, error_msg''' if not self._post_func: - LOG.warning('[DocumentProcessor] No post function configured,' - ' task status callback will not be performed!') + if self._callback_url: + LOG.info('[DocumentProcessor] No custom post function configured,' + ' using built-in HTTP callback') + else: + LOG.warning('[DocumentProcessor] No custom post function configured, built-in HTTP callback' + ' will only run when callback_url or feedback_url is provided in task request') return True if not callable(self._post_func): LOG.error('[DocumentProcessor] Post function is not callable') @@ -450,18 +636,30 @@ def _check_post_func(self) -> bool: return False return True - def _callback(self, task_id: str, task_status: str = None, error_code: str = None, error_msg: str = None): + def _callback(self, finished_task: Dict[str, Any]): '''callback to service''' - message = f'Task {task_id} finished with status: {task_status}.' + task_id = finished_task.get('task_id') + task_status = finished_task.get('task_status') + error_code = finished_task.get('error_code') + error_msg = finished_task.get('error_msg') + message = f'Task {task_id} callback status: {task_status}.' if error_msg: message += f' Error code: {error_code}, error_msg: {error_msg}.' 
LOG.info(f'[DocumentProcessor] {message}') - if self._post_func: - try: + try: + if self._post_func: self._post_func(task_id, task_status, error_code, error_msg) - except Exception as e: - LOG.error(f'[DocumentProcessor] Failed to call post function: {e}, {traceback.format_exc()}') + else: + self._default_post_func(finished_task) + except Exception as e: + LOG.error(f'[DocumentProcessor] Failed to call post function: {e}, {traceback.format_exc()}') + if self._post_func: + try: + self._default_post_func(finished_task) + except Exception: + raise e + else: raise e def __call__(self, func_name: str, *args, **kwargs): @@ -470,14 +668,15 @@ def __call__(self, func_name: str, *args, **kwargs): def __init__(self, port: int = None, url: str = None, num_workers: int = 1, db_config: Optional[Dict[str, Any]] = None, launcher: Optional[Launcher] = None, post_func: Optional[Callable] = None, - path_prefix: Optional[str] = None): + path_prefix: Optional[str] = None, callback_url: Optional[str] = None): super().__init__() self._raw_impl = None # save the reference of the original Impl object self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') if not url: # create the Impl object (lazy loading, no threads created) self._raw_impl = DocumentProcessor._Impl(num_workers=num_workers, db_config=self._db_config, - post_func=post_func, path_prefix=path_prefix) + post_func=post_func, path_prefix=path_prefix, + callback_url=callback_url) self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) else: self._impl = UrlModule(url=ensure_call_endpoint(url)) @@ -502,6 +701,11 @@ def wait(self): return impl.wait() LOG.warning('[DocumentProcessor] wait() is no-op in UrlModule mode') + def set_callback_url(self, callback_url: Optional[str]): + if isinstance(self._impl, UrlModule): + raise RuntimeError('set_callback_url is only supported in local server mode') + return self._dispatch('set_callback_url', callback_url) + def _dispatch(self, method: str, *args, **kwargs): try: impl = self._impl diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index 4ce0c364a..7ef7c7550 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -170,8 +170,35 @@ def _exec_update_meta_task(self, processor: _Processor, task_id: str, payload: d f'error: {e}') raise e + @staticmethod + def _resolve_callback_url(payload: dict): + return payload.get('callback_url') or payload.get('feedback_url') + + @staticmethod + def _build_task_context(task_type: str, payload: dict) -> dict: + items = [] + if task_type in (TaskType.DOC_ADD.value, TaskType.DOC_REPARSE.value, TaskType.DOC_UPDATE_META.value): + file_infos = payload.get('file_infos') or [] + items = [{ + 'doc_id': file_info.get('doc_id'), + 'file_path': file_info.get('file_path'), + 'metadata': file_info.get('metadata'), + 'reparse_group': file_info.get('reparse_group'), + } for file_info in file_infos] + elif task_type == TaskType.DOC_DELETE.value: + items = [{'doc_id': doc_id} for doc_id in (payload.get('doc_ids') or [])] + if not items: + items = [{}] + return { + 'task_type': task_type, + 'kb_id': payload.get('kb_id'), + 'algo_id': payload.get('algo_id'), + 'items': items, + } + def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: TaskStatus, - error_code: str = None, error_msg: str = None): + error_code: str = None, error_msg: str = None, + callback_url: str = None, task_context_json: str = None): try: 
self._lazy_init() self._finished_task_queue.enqueue( @@ -179,11 +206,15 @@ def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: Task task_type=task_type, task_status=task_status.value, finished_at=datetime.now(), + callback_url=callback_url, + task_context_json=task_context_json, error_code=error_code if error_code else '200', error_msg=error_msg if error_msg else 'success' ) - if task_status == TaskStatus.FINISHED: - LOG.info(f'[DocumentProcessorWorker._Impl] Task {task_id} finished successfully') + if task_status == TaskStatus.WORKING: + LOG.info(f'[DocumentProcessorWorker._Impl] Task {task_id} started') + elif task_status == TaskStatus.SUCCESS: + LOG.info(f'[DocumentProcessorWorker._Impl] Task {task_id} completed successfully') else: LOG.error(f'[DocumentProcessorWorker._Impl] Task {task_id} completed with status {task_status}:' f' {error_msg}') @@ -207,9 +238,18 @@ def _worker_impl(self): if not algo_id: raise ValueError(f'[DocumentProcessorWorker._Impl] task_id {task_id} is missing algo_id in ' f'payload: {payload}') + callback_url = self._resolve_callback_url(payload) + task_context_json = json.dumps(self._build_task_context(task_type, payload), ensure_ascii=False) LOG.info(f'[DocumentProcessorWorker._Impl] Start processing task {task_id}, type: {task_type},' f' algo_id: {algo_id}') + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.WORKING, + callback_url=callback_url, + task_context_json=task_context_json, + ) processor = self._get_or_create_processor(algo_id) if task_type == TaskType.DOC_ADD.value: @@ -223,14 +263,30 @@ def _worker_impl(self): else: raise ValueError(f'[DocumentProcessorWorker._Impl] Unknown task type: {task_type}') - self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FINISHED, - error_code='200', error_msg='success') + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.SUCCESS, + error_code='200', + error_msg='success', + callback_url=callback_url, + task_context_json=task_context_json, + ) except Exception as e: LOG.error(f'[DocumentProcessorWorker._Impl] Failed to run task {task_id}: {e},' f' {traceback.format_exc()}') if task_id and task_type: - self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FAILED, - error_code=type(e).__name__, error_msg=str(e)) + callback_url = locals().get('callback_url') + task_context_json = locals().get('task_context_json') + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.FAILED, + error_code=type(e).__name__, + error_msg=str(e), + callback_url=callback_url, + task_context_json=task_context_json, + ) time.sleep(WORKER_ERROR_RETRY_INTERVAL) continue diff --git a/lazyllm/tools/rag/store/document_store.py b/lazyllm/tools/rag/store/document_store.py index e9bcaddfc..21db003e1 100644 --- a/lazyllm/tools/rag/store/document_store.py +++ b/lazyllm/tools/rag/store/document_store.py @@ -1,4 +1,6 @@ +import hashlib import os +import re import traceback import lazyllm from collections import defaultdict @@ -20,6 +22,9 @@ class _DocumentStore(object): + _COLLECTION_NAME_PATTERN = re.compile(r'[^a-z0-9_]+') + _COLLECTION_NAME_MAX_LEN = 255 + def __init__(self, algo_name: str, store: Union[Dict, LazyLLMStoreBase], group_embed_keys: Optional[Dict[str, Set[str]]] = None, embed: Optional[Dict[str, Callable]] = None, embed_dims: Optional[Dict[str, int]] = None, embed_datatypes: Optional[Dict[str, DataType]] = None, 
@@ -439,4 +444,16 @@ def _deserialize_node(self, data: dict, score: Optional[float] = None) -> DocNod return node.with_sim_score(score) if score else node def _gen_collection_name(self, group: str) -> str: - return f'col_{self._algo_name}_{group}'.lower() + raw_name = f'col_{self._algo_name}_{group}'.lower() + normalized = self._COLLECTION_NAME_PATTERN.sub('_', raw_name).strip('_') + if not normalized: + normalized = 'col' + if normalized[0].isdigit(): + normalized = f'col_{normalized}' + if normalized == raw_name and len(normalized) <= self._COLLECTION_NAME_MAX_LEN: + return normalized + + digest = hashlib.sha1(raw_name.encode()).hexdigest()[:12] + max_prefix_len = self._COLLECTION_NAME_MAX_LEN - len(digest) - 1 + prefix = normalized[:max_prefix_len].rstrip('_') or 'col' + return f'{prefix}_{digest}' diff --git a/tests/basic_tests/RAG/test_doc_processor.py b/tests/basic_tests/RAG/test_doc_processor.py index c6980faf9..30d0a91ae 100644 --- a/tests/basic_tests/RAG/test_doc_processor.py +++ b/tests/basic_tests/RAG/test_doc_processor.py @@ -9,7 +9,7 @@ from lazyllm.tools.rag.parsing_service.base import TaskStatus from lazyllm import Document, Retriever -STATIC_STATUS = [TaskStatus.FINISHED.value, TaskStatus.FAILED.value, TaskStatus.CANCELED.value] +STATIC_STATUS = [TaskStatus.SUCCESS.value, TaskStatus.FAILED.value, TaskStatus.CANCELED.value] records = [] def post_func_sample(task_id: str, task_status: str, error_code: str = None, error_msg: str = None): diff --git a/tests/basic_tests/RAG/test_doc_service_mock.py b/tests/basic_tests/RAG/test_doc_service_mock.py new file mode 100644 index 000000000..a7e3cc593 --- /dev/null +++ b/tests/basic_tests/RAG/test_doc_service_mock.py @@ -0,0 +1,966 @@ +import io +import os +import socket +import shutil +import tempfile +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from uuid import uuid4 + +import pytest +import requests + +from lazyllm.tools.rag.doc_service import DocServer +from lazyllm.tools.rag.doc_service.base import ( + AddFileItem, CallbackEventType, DeleteRequest, DocServiceError, DocStatus, KbUpdateRequest, ReparseRequest, + SourceType, TaskCallbackRequest, UploadRequest, +) +from lazyllm.tools.rag.doc_service.doc_manager import DocManager, _ParserClient +from lazyllm.tools.rag.parsing_service.base import TaskType +from lazyllm.tools.rag.utils import BaseResponse + + +@pytest.mark.skip_on_win +class TestDocServiceMock: + @staticmethod + def _ensure_bindable(): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + sock.bind(('127.0.0.1', 0)) + except OSError as exc: + if 'operation not permitted' in str(exc).lower(): + pytest.skip('Socket bind is not permitted in current environment') + raise + finally: + sock.close() + + @classmethod + def setup_class(cls): + cls._ensure_bindable() + cls._parser_url = os.getenv('LAZYLLM_DOC_SERVICE_TEST_PARSER_URL') + if not cls._parser_url: + pytest.skip('LAZYLLM_DOC_SERVICE_TEST_PARSER_URL is required for real parser integration tests') + cls._tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_') + cls._storage_dir = os.path.join(cls._tmp_dir, 'uploads') + os.makedirs(cls._storage_dir, exist_ok=True) + + cls._seed_path = os.path.join(cls._tmp_dir, 'seed.txt') + with open(cls._seed_path, 'w', encoding='utf-8') as f: + f.write('seed content') + + cls._db_config = { + 'db_type': 'sqlite', + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'db_name': os.path.join(cls._tmp_dir, 'doc_service.db'), + } + cls.server = DocServer( + 
db_config=cls._db_config, + parser_url=cls._parser_url, + storage_dir=cls._storage_dir, + ) + cls.server.start() + cls.base_url = cls.server._impl._url.rsplit('/', 1)[0] + deadline = time.time() + 10 + while time.time() < deadline: + try: + resp = requests.get(f'{cls.base_url}/v1/health', timeout=3) + if resp.status_code == 200: + break + except Exception: + pass + time.sleep(0.2) + + @classmethod + def teardown_class(cls): + cls.server.stop() + shutil.rmtree(cls._tmp_dir, ignore_errors=True) + + def _wait_task(self, task_id, target_statuses, timeout=8): + deadline = time.time() + timeout + last = None + while time.time() < deadline: + resp = requests.get(f'{self.base_url}/v1/tasks/{task_id}', timeout=5) + assert resp.status_code == 200 + last = resp.json()['data'] + if last['status'] in target_statuses: + return last + time.sleep(0.1) + raise AssertionError(f'task {task_id} not finished in time, last={last}') + + def test_p0_endpoints_and_core_flows(self): + health = requests.get(f'{self.base_url}/v1/health', timeout=5) + assert health.status_code == 200 + + kb_create = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_a'}, timeout=5) + assert kb_create.status_code == 200 + assert kb_create.json()['data']['kb_id'] == 'kb_a' + + kb_list = requests.get(f'{self.base_url}/v1/kbs', timeout=5) + assert kb_list.status_code == 200 + assert any(item['kb_id'] == 'kb_a' for item in kb_list.json()['data']['items']) + + algo_list = requests.get(f'{self.base_url}/v1/algo/list', timeout=5) + assert algo_list.status_code == 200 + assert any(item['algo_id'] == '__default__' for item in algo_list.json()['data']) + + algo_groups = requests.get(f'{self.base_url}/v1/algo/__default__/groups', timeout=5) + assert algo_groups.status_code == 200 + assert len(algo_groups.json()['data']) > 0 + + upload = requests.post( + f'{self.base_url}/v1/docs/upload', + params={'kb_id': 'kb_a', 'algo_id': '__default__'}, + files=[('files', ('upload.txt', io.BytesIO(b'upload content'), 'text/plain'))], + timeout=8, + ) + assert upload.status_code == 200 + upload_item = upload.json()['data']['items'][0] + doc_upload = upload_item['doc_id'] + upload_task = upload_item['task_id'] + self._wait_task(upload_task, {'SUCCESS'}) + + add = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_a', + 'algo_id': '__default__', + 'items': [{'file_path': self._seed_path, 'doc_id': 'seed-doc-1', 'metadata': {'owner': 'u1'}}], + }, + timeout=8, + ) + assert add.status_code == 200 + add_item = add.json()['data']['items'][0] + doc_add = add_item['doc_id'] + add_task = add_item['task_id'] + self._wait_task(add_task, {'SUCCESS'}) + + meta_patch = requests.post( + f'{self.base_url}/v1/docs/metadata/patch', + json={ + 'kb_id': 'kb_a', + 'algo_id': '__default__', + 'items': [{'doc_id': doc_add, 'patch': {'tag': 'patched'}}], + }, + timeout=8, + ) + assert meta_patch.status_code == 200 + meta_task = meta_patch.json()['data']['items'][0]['task_id'] + self._wait_task(meta_task, {'SUCCESS'}) + + reparse = requests.post( + f'{self.base_url}/v1/docs/reparse', + json={'kb_id': 'kb_a', 'algo_id': '__default__', 'doc_ids': [doc_add]}, + timeout=8, + ) + assert reparse.status_code == 200 + reparse_task = reparse.json()['data']['task_ids'][0] + self._wait_task(reparse_task, {'SUCCESS'}) + + kb_b = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_b'}, timeout=5) + assert kb_b.status_code == 200 + + transfer = requests.post( + f'{self.base_url}/v1/docs/transfer', + json={ + 'items': [ + { + 'doc_id': doc_add, + 
'source_kb_id': 'kb_a', + 'source_algo_id': '__default__', + 'target_kb_id': 'kb_b', + 'target_algo_id': '__default__', + 'mode': 'copy', + } + ] + }, + timeout=8, + ) + assert transfer.status_code == 200 + transfer_task = transfer.json()['data']['items'][0]['task_id'] + self._wait_task(transfer_task, {'SUCCESS'}) + + docs = requests.get( + f'{self.base_url}/v1/docs', + params={'kb_id': 'kb_a', 'include_deleted_or_canceled': True, 'keyword': 'seed'}, + timeout=8, + ) + assert docs.status_code == 200 + assert docs.json()['data']['total'] >= 1 + + doc_detail = requests.get(f'{self.base_url}/v1/docs/{doc_add}', timeout=8) + assert doc_detail.status_code == 200 + assert doc_detail.json()['data']['doc']['metadata'].get('tag') == 'patched' + + tasks = requests.get(f'{self.base_url}/v1/tasks', params={'status': ['SUCCESS', 'WAITING']}, timeout=8) + assert tasks.status_code == 200 + assert tasks.json()['data']['total'] >= 1 + + task_detail = requests.get(f'{self.base_url}/v1/tasks/{reparse_task}', timeout=8) + assert task_detail.status_code == 200 + + cancel = requests.post(f'{self.base_url}/v1/tasks/cancel', json={'task_id': reparse_task}, timeout=8) + assert cancel.status_code == 200 + assert cancel.json()['data']['task_id'] == reparse_task + + delete = requests.post( + f'{self.base_url}/v1/docs/delete', + json={'kb_id': 'kb_a', 'algo_id': '__default__', 'doc_ids': [doc_upload]}, + timeout=8, + ) + assert delete.status_code == 200 + delete_task = delete.json()['data']['items'][0]['task_id'] + self._wait_task(delete_task, {'DELETED'}) + + docs_filtered = requests.get( + f'{self.base_url}/v1/docs', + params={'kb_id': 'kb_a', 'include_deleted_or_canceled': False}, + timeout=8, + ) + assert docs_filtered.status_code == 200 + + cb = requests.post( + f'{self.base_url}/v1/internal/callbacks/tasks', + json={ + 'task_id': 'non-exist-task', + 'event_type': 'FINISH', + 'status': 'SUCCESS', + 'payload': {'task_type': 'DOC_ADD', 'doc_id': 'nope', 'kb_id': 'kb_a', 'algo_id': '__default__'}, + }, + timeout=8, + ) + assert cb.status_code == 200 + assert cb.json()['data']['ack'] is True + + kb_delete = requests.delete(f'{self.base_url}/v1/kbs/kb_a', timeout=8) + assert kb_delete.status_code == 200 + + def test_document_manager_supports_doc_server_port(self): + from lazyllm import Document + + tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_port_') + storage_dir = os.path.join(tmp_dir, 'uploads') + os.makedirs(storage_dir, exist_ok=True) + fixed_port = 18898 + doc = Document(dataset_path=storage_dir, manager=True, doc_server_port=fixed_port, name='doc_port_test') + try: + self._ensure_bindable() + doc.start() + base_url = doc.manager.url.rsplit('/', 1)[0] + assert base_url.endswith(f':{fixed_port}') + health = requests.get(f'{base_url}/v1/health', timeout=5) + assert health.status_code == 200 + finally: + doc.stop() + shutil.rmtree(tmp_dir, ignore_errors=True) + + def test_missing_p0_endpoints_exist(self): + kb_create = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_endpoints'}, timeout=5) + assert kb_create.status_code == 200 + + chunks = requests.get(f'{self.base_url}/v1/chunks', timeout=5) + assert chunks.status_code == 200 + assert chunks.json()['data']['items'] == [] + + algorithms = requests.get(f'{self.base_url}/v1/algorithms', timeout=5) + assert algorithms.status_code == 200 + assert len(algorithms.json()['data']['items']) >= 1 + + algo_info = requests.post( + f'{self.base_url}/v1/algorithms/info', json={'algo_id': '__default__'}, timeout=5, + ) + assert algo_info.status_code == 200 + assert 
algo_info.json()['data']['algo_id'] == '__default__' + + add = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_endpoints', + 'algo_id': '__default__', + 'items': [{'file_path': self._seed_path, 'doc_id': 'seed-doc-endpoints'}], + }, + timeout=8, + ) + assert add.status_code == 200 + task_id = add.json()['data']['items'][0]['task_id'] + + task_info = requests.post(f'{self.base_url}/v1/tasks/info', json={'task_id': task_id}, timeout=5) + assert task_info.status_code == 200 + assert task_info.json()['data']['task_id'] == task_id + + task_batch = requests.post(f'{self.base_url}/v1/tasks/batch', json={'task_ids': [task_id]}, timeout=5) + assert task_batch.status_code == 200 + assert len(task_batch.json()['data']['items']) == 1 + + kb_delete = requests.delete(f'{self.base_url}/v1/kbs', json={'kb_ids': ['kb_endpoints']}, timeout=8) + assert kb_delete.status_code == 200 + assert len(kb_delete.json()['data']['items']) == 1 + + def test_idempotency_replay_and_conflict(self): + file_path = os.path.join(self._tmp_dir, 'idem.txt') + with open(file_path, 'w', encoding='utf-8') as f: + f.write('idempotent content') + create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_idem'}, timeout=5) + assert create_kb.status_code == 200 + + payload = { + 'kb_id': 'kb_idem', + 'algo_id': '__default__', + 'idempotency_key': 'idem-add-key', + 'items': [{'file_path': file_path, 'doc_id': 'idem-doc-1'}], + } + first = requests.post(f'{self.base_url}/v1/docs/add', json=payload, timeout=8) + second = requests.post(f'{self.base_url}/v1/docs/add', json=payload, timeout=8) + assert first.status_code == 200 + assert second.status_code == 200 + assert first.json()['data']['items'][0]['task_id'] == second.json()['data']['items'][0]['task_id'] + + conflict_payload = dict(payload) + conflict_payload['items'] = [{'file_path': file_path, 'doc_id': 'idem-doc-2'}] + conflict = requests.post(f'{self.base_url}/v1/docs/add', json=conflict_payload, timeout=8) + assert conflict.status_code == 409 + assert conflict.json()['data']['biz_code'] == 'E_IDEMPOTENCY_CONFLICT' + + def test_upload_idempotency_replay_and_conflict(self): + create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_upload_idem'}, timeout=5) + assert create_kb.status_code == 200 + + params = { + 'kb_id': 'kb_upload_idem', + 'algo_id': '__default__', + 'idempotency_key': 'idem-upload-key', + } + first = requests.post( + f'{self.base_url}/v1/docs/upload', + params=params, + files=[('files', ('idem-upload.txt', io.BytesIO(b'upload idem content'), 'text/plain'))], + timeout=8, + ) + second = requests.post( + f'{self.base_url}/v1/docs/upload', + params=params, + files=[('files', ('idem-upload.txt', io.BytesIO(b'upload idem content'), 'text/plain'))], + timeout=8, + ) + assert first.status_code == 200 + assert second.status_code == 200 + assert first.json()['data']['items'][0]['task_id'] == second.json()['data']['items'][0]['task_id'] + + conflict = requests.post( + f'{self.base_url}/v1/docs/upload', + params=params, + files=[('files', ('idem-upload.txt', io.BytesIO(b'upload idem changed'), 'text/plain'))], + timeout=8, + ) + assert conflict.status_code == 409 + assert conflict.json()['data']['biz_code'] == 'E_IDEMPOTENCY_CONFLICT' + + def test_upload_same_filename_does_not_override_existing_file(self): + create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_same_name'}, timeout=5) + assert create_kb.status_code == 200 + + first = requests.post( + f'{self.base_url}/v1/docs/upload', + params={'kb_id': 
'kb_same_name', 'algo_id': '__default__'}, + files=[('files', ('same-name.txt', io.BytesIO(b'first content'), 'text/plain'))], + timeout=8, + ) + second = requests.post( + f'{self.base_url}/v1/docs/upload', + params={'kb_id': 'kb_same_name', 'algo_id': '__default__'}, + files=[('files', ('same-name.txt', io.BytesIO(b'second content'), 'text/plain'))], + timeout=8, + ) + assert first.status_code == 200 + assert second.status_code == 200 + + first_item = first.json()['data']['items'][0] + second_item = second.json()['data']['items'][0] + assert first_item['doc_id'] != second_item['doc_id'] + self._wait_task(first_item['task_id'], {'SUCCESS'}) + self._wait_task(second_item['task_id'], {'SUCCESS'}) + + first_detail = requests.get(f'{self.base_url}/v1/docs/{first_item["doc_id"]}', timeout=8) + second_detail = requests.get(f'{self.base_url}/v1/docs/{second_item["doc_id"]}', timeout=8) + assert first_detail.status_code == 200 + assert second_detail.status_code == 200 + assert first_detail.json()['data']['doc']['path'] != second_detail.json()['data']['doc']['path'] + + def test_idempotency_atomic_claim(self): + file_path = os.path.join(self._tmp_dir, 'idem_atomic.txt') + with open(file_path, 'w', encoding='utf-8') as f: + f.write('idempotent atomic content') + create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_idem_atomic'}, timeout=5) + assert create_kb.status_code == 200 + + payload = { + 'kb_id': 'kb_idem_atomic', + 'algo_id': '__default__', + 'idempotency_key': 'idem-atomic-key', + 'items': [{'file_path': file_path, 'doc_id': 'idem-atomic-doc'}], + } + + def _send(): + return requests.post(f'{self.base_url}/v1/docs/add', json=payload, timeout=8) + + with ThreadPoolExecutor(max_workers=6) as pool: + responses = list(pool.map(lambda _: _send(), range(6))) + + statuses = [resp.status_code for resp in responses] + assert all(status in (200, 409) for status in statuses) + success_payloads = [resp.json()['data'] for resp in responses if resp.status_code == 200] + unique_task_ids = {item['items'][0]['task_id'] for item in success_payloads} + assert len(unique_task_ids) == 1 + for resp in responses: + if resp.status_code == 409: + assert resp.json()['data']['biz_code'] in {'E_IDEMPOTENCY_IN_PROGRESS', 'E_IDEMPOTENCY_CONFLICT'} + + replay = requests.post(f'{self.base_url}/v1/docs/add', json=payload, timeout=8) + assert replay.status_code == 200 + assert replay.json()['data']['items'][0]['task_id'] in unique_task_ids + + def test_illegal_state_transition(self): + file_path = os.path.join(self._tmp_dir, 'illegal.txt') + with open(file_path, 'w', encoding='utf-8') as f: + f.write('illegal transition content') + create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_illegal'}, timeout=5) + assert create_kb.status_code == 200 + + add = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_illegal', + 'algo_id': '__default__', + 'items': [{'file_path': file_path, 'doc_id': 'illegal-doc-1'}], + }, + timeout=8, + ) + assert add.status_code == 200 + doc_id = add.json()['data']['items'][0]['doc_id'] + task_id = add.json()['data']['items'][0]['task_id'] + self._wait_task(task_id, {'SUCCESS'}) + + delete = requests.post( + f'{self.base_url}/v1/docs/delete', + json={'kb_id': 'kb_illegal', 'algo_id': '__default__', 'doc_ids': [doc_id]}, + timeout=8, + ) + assert delete.status_code == 200 + + reparse_while_deleting = requests.post( + f'{self.base_url}/v1/docs/reparse', + json={'kb_id': 'kb_illegal', 'algo_id': '__default__', 'doc_ids': [doc_id]}, + timeout=8, + 
) + assert reparse_while_deleting.status_code == 409 + assert reparse_while_deleting.json()['data']['biz_code'] == 'E_STATE_CONFLICT' + + delete_again = requests.post( + f'{self.base_url}/v1/docs/delete', + json={'kb_id': 'kb_illegal', 'algo_id': '__default__', 'doc_ids': [doc_id]}, + timeout=8, + ) + assert delete_again.status_code == 409 + + add_again = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_illegal', + 'algo_id': '__default__', + 'items': [{'file_path': file_path, 'doc_id': doc_id}], + }, + timeout=8, + ) + assert add_again.status_code == 409 + + def test_kb_algo_binding_and_transfer_validation(self): + create_kb = requests.post( + f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_bind', 'algo_id': '__default__'}, timeout=5, + ) + assert create_kb.status_code == 200 + + rebind = requests.post( + f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_bind', 'algo_id': 'another_algo'}, timeout=5, + ) + assert rebind.status_code == 409 + assert rebind.json()['data']['biz_code'] == 'E_STATE_CONFLICT' + + file_path = os.path.join(self._tmp_dir, 'binding.txt') + with open(file_path, 'w', encoding='utf-8') as f: + f.write('binding content') + + mismatch = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_bind', + 'algo_id': 'another_algo', + 'items': [{'file_path': file_path, 'doc_id': 'bind-doc'}], + }, + timeout=8, + ) + assert mismatch.status_code == 400 + assert mismatch.json()['data']['biz_code'] == 'E_INVALID_PARAM' + + add = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_bind', + 'algo_id': '__default__', + 'items': [{'file_path': file_path, 'doc_id': 'bind-doc'}], + }, + timeout=8, + ) + assert add.status_code == 200 + doc_id = add.json()['data']['items'][0]['doc_id'] + self._wait_task(add.json()['data']['items'][0]['task_id'], {'SUCCESS'}) + + invalid_transfer = requests.post( + f'{self.base_url}/v1/docs/transfer', + json={ + 'items': [{ + 'doc_id': doc_id, + 'source_kb_id': 'kb_bind', + 'source_algo_id': '__default__', + 'target_kb_id': 'kb_bind', + 'target_algo_id': '__default__', + 'mode': 'invalid', + }] + }, + timeout=8, + ) + assert invalid_transfer.status_code == 400 + assert invalid_transfer.json()['data']['biz_code'] == 'E_INVALID_PARAM' + + def test_stale_callback_ignored(self): + file_path = os.path.join(self._tmp_dir, 'stale.txt') + with open(file_path, 'w', encoding='utf-8') as f: + f.write('stale callback content') + create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_stale'}, timeout=5) + assert create_kb.status_code == 200 + + add = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_stale', + 'algo_id': '__default__', + 'items': [{'file_path': file_path, 'doc_id': 'stale-doc-1'}], + }, + timeout=8, + ) + assert add.status_code == 200 + doc_id = add.json()['data']['items'][0]['doc_id'] + self._wait_task(add.json()['data']['items'][0]['task_id'], {'SUCCESS'}) + + first = requests.post( + f'{self.base_url}/v1/docs/reparse', + json={'kb_id': 'kb_stale', 'algo_id': '__default__', 'doc_ids': [doc_id]}, + timeout=8, + ) + assert first.status_code == 200 + first_task_id = first.json()['data']['task_ids'][0] + + second = requests.post( + f'{self.base_url}/v1/docs/reparse', + json={'kb_id': 'kb_stale', 'algo_id': '__default__', 'doc_ids': [doc_id]}, + timeout=8, + ) + assert second.status_code == 200 + second_task_id = second.json()['data']['task_ids'][0] + assert first_task_id != second_task_id + + stale = requests.post( + 
f'{self.base_url}/v1/internal/callbacks/tasks', + json={ + 'callback_id': 'stale-callback-1', + 'task_id': first_task_id, + 'event_type': 'FINISH', + 'status': 'SUCCESS', + }, + timeout=8, + ) + assert stale.status_code == 200 + assert stale.json()['data']['ignored_reason'] == 'stale_task_callback' + + duplicate = requests.post( + f'{self.base_url}/v1/internal/callbacks/tasks', + json={ + 'callback_id': 'stale-callback-1', + 'task_id': first_task_id, + 'event_type': 'FINISH', + 'status': 'SUCCESS', + }, + timeout=8, + ) + assert duplicate.status_code == 200 + assert duplicate.json()['data']['deduped'] is True + + def test_get_doc_404_is_wrapped(self): + resp = requests.get(f'{self.base_url}/v1/docs/not-exists-doc', timeout=5) + assert resp.status_code == 404 + body = resp.json() + assert body['code'] == 404 + assert body['data']['biz_code'] == 'E_NOT_FOUND' + + def test_delete_kbs_empty_payload_returns_400(self): + resp = requests.delete(f'{self.base_url}/v1/kbs', json={'kb_ids': []}, timeout=5) + assert resp.status_code == 400 + assert resp.json()['data']['biz_code'] == 'E_INVALID_PARAM' + + def test_kb_update_pagination_and_batch_query(self): + first = requests.post( + f'{self.base_url}/v1/kbs', + json={'kb_id': 'kb_page_1', 'display_name': 'Page 1', 'algo_id': '__default__'}, + timeout=5, + ) + second = requests.post( + f'{self.base_url}/v1/kbs', + json={'kb_id': 'kb_page_2', 'display_name': 'Page 2', 'algo_id': '__default__'}, + timeout=5, + ) + assert first.status_code == 200 + assert second.status_code == 200 + + paged = requests.get(f'{self.base_url}/v1/kbs', params={'page': 1, 'page_size': 1}, timeout=5) + assert paged.status_code == 200 + paged_data = paged.json()['data'] + assert paged_data['page'] == 1 + assert paged_data['page_size'] == 1 + assert paged_data['total'] >= 2 + assert len(paged_data['items']) == 1 + + detail = requests.get(f'{self.base_url}/v1/kbs/kb_page_1', timeout=5) + assert detail.status_code == 200 + assert detail.json()['data']['algo_id'] == '__default__' + + updated = requests.post( + f'{self.base_url}/v1/kbs/kb_page_1/update', + json={ + 'display_name': 'Page 1 Updated', + 'description': 'updated description', + 'owner_id': 'owner-a', + 'meta': {'scene': 'pagination-test'}, + }, + timeout=5, + ) + assert updated.status_code == 200 + updated_data = updated.json()['data'] + assert updated_data['display_name'] == 'Page 1 Updated' + assert updated_data['meta']['scene'] == 'pagination-test' + + batch = requests.post( + f'{self.base_url}/v1/kbs/batch', + json={'kb_ids': ['kb_page_1', 'kb_missing']}, + timeout=5, + ) + assert batch.status_code == 200 + batch_data = batch.json()['data'] + assert len(batch_data['items']) == 1 + assert batch_data['items'][0]['kb_id'] == 'kb_page_1' + assert batch_data['missing_kb_ids'] == ['kb_missing'] + + +class TestDocServiceMockLocal: + @classmethod + def setup_class(cls): + cls._tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_local_') + cls._seed_path = os.path.join(cls._tmp_dir, 'seed.txt') + with open(cls._seed_path, 'w', encoding='utf-8') as f: + f.write('local seed content') + cls._db_config = { + 'db_type': 'sqlite', + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'db_name': os.path.join(cls._tmp_dir, 'doc_service_local.db'), + } + cls.manager = DocManager(db_config=cls._db_config, parser_url='http://parser.test') + cls._pending_task_status = {} + + def _queue_task(task_id: str, final_status: DocStatus): + cls._pending_task_status[task_id] = final_status + + cls.manager._parser_client.add_doc = 
lambda task_id, algo_id, kb_id, doc_id, file_path, metadata=None, reparse_group=None: ( + _queue_task(task_id, DocStatus.SUCCESS) or + BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + ) + cls.manager._parser_client.update_meta = lambda task_id, algo_id, kb_id, doc_id, metadata=None, file_path=None: ( + _queue_task(task_id, DocStatus.SUCCESS) or + BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + ) + cls.manager._parser_client.delete_doc = lambda task_id, algo_id, kb_id, doc_id: ( + _queue_task(task_id, DocStatus.SUCCESS) or + BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + ) + cls.manager._parser_client.cancel_task = lambda task_id: BaseResponse( + code=200, msg='success', data={'task_id': task_id, 'cancel_status': True} + ) + cls.manager._parser_client.list_algorithms = lambda: BaseResponse( + code=200, msg='success', data=[{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}] + ) + cls.manager._parser_client.get_algorithm_groups = lambda algo_id: BaseResponse( + code=200, + msg='success', + data=[{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}] if algo_id == '__default__' else None, + ) + + @classmethod + def teardown_class(cls): + shutil.rmtree(cls._tmp_dir, ignore_errors=True) + + def _wait_task(self, task_id, target_statuses, timeout=8): + deadline = time.time() + timeout + last = None + while time.time() < deadline: + resp = self.manager.get_task(task_id) + assert resp.code == 200 + last = resp.data + if last['status'] in target_statuses: + return last + pending_status = self._pending_task_status.pop(task_id, None) + if pending_status is not None: + self.manager.on_task_callback(TaskCallbackRequest( + task_id=task_id, + event_type=CallbackEventType.FINISH, + status=pending_status, + )) + time.sleep(0.05) + raise AssertionError(f'task {task_id} not finished in time, last={last}') + + def _make_file(self, name: str, content: str): + file_path = os.path.join(self._tmp_dir, name) + with open(file_path, 'w', encoding='utf-8') as f: + f.write(content) + return file_path + + def test_manager_atomic_idempotency(self): + started = [] + + def handler(): + started.append(time.time()) + time.sleep(0.2) + return {'task_id': str(uuid4())} + + with ThreadPoolExecutor(max_workers=2) as pool: + future = pool.submit(self.manager.run_idempotent, '/local/atomic', 'same-key', {'k': 1}, handler) + time.sleep(0.05) + with pytest.raises(DocServiceError) as exc: + self.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) + result = future.result(timeout=2) + + assert exc.value.biz_code == 'E_IDEMPOTENCY_IN_PROGRESS' + replay = self.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) + assert len(started) == 1 + assert replay == result + + def test_manager_kb_algo_binding(self): + self.manager.create_kb('kb_local_bind', algo_id='__default__') + file_path = self._make_file('local_bind.txt', 'local bind content') + with pytest.raises(DocServiceError) as exc: + self.manager.upload(UploadRequest( + kb_id='kb_local_bind', + algo_id='wrong_algo', + items=[AddFileItem(file_path=file_path, doc_id='local-bind-doc')], + )) + assert exc.value.biz_code == 'E_INVALID_PARAM' + + def test_manager_stale_callback_and_state_conflict(self): + self.manager.create_kb('kb_local_stale', algo_id='__default__') + file_path = self._make_file('local_stale.txt', 'local stale content') + uploaded = 
self.manager.upload(UploadRequest( + kb_id='kb_local_stale', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='local-stale-doc')], + )) + self._wait_task(uploaded[0]['task_id'], {'SUCCESS'}) + first_task_id = self.manager.reparse(ReparseRequest( + kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], + ))[0] + second_task_id = self.manager.reparse(ReparseRequest( + kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], + ))[0] + stale_resp = self.manager.on_task_callback(TaskCallbackRequest( + callback_id='local-stale-callback', + task_id=first_task_id, + event_type=CallbackEventType.FINISH, + status=DocStatus.SUCCESS, + )) + assert stale_resp['ignored_reason'] == 'stale_task_callback' + self.manager.delete(DeleteRequest(kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'])) + with pytest.raises(DocServiceError) as exc: + self.manager.reparse(ReparseRequest( + kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], + )) + assert exc.value.biz_code == 'E_STATE_CONFLICT' + assert second_task_id != first_task_id + + def test_manager_missing_endpoint_surrogates(self): + self.manager.create_kb('kb_local_info', algo_id='__default__') + file_path = self._make_file('local_info.txt', 'local info content') + uploaded = self.manager.upload(UploadRequest( + kb_id='kb_local_info', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='local-info-doc')], + )) + algorithms = self.manager.list_algorithms_compat() + assert len(algorithms['items']) >= 1 + algo_info = self.manager.get_algorithm_info('__default__') + assert algo_info['algo_id'] == '__default__' + chunks = self.manager.list_chunks() + assert chunks['items'] == [] + tasks_batch = self.manager.get_tasks_batch([uploaded[0]['task_id']]) + assert len(tasks_batch['items']) == 1 + + def test_delete_kbs_empty_list_rejected(self): + with pytest.raises(DocServiceError) as exc: + self.manager.delete_kbs([]) + assert exc.value.biz_code == 'E_INVALID_PARAM' + + def test_manager_rejects_unknown_kb_algorithm(self): + with pytest.raises(DocServiceError) as exc: + self.manager.create_kb('kb_local_unknown_algo', algo_id='missing_algo') + assert exc.value.biz_code == 'E_INVALID_PARAM' + + def test_manager_update_kb_can_clear_nullable_fields(self): + self.manager.create_kb( + 'kb_local_clearable', + display_name='Clearable', + description='to be cleared', + owner_id='owner-x', + meta={'tag': 'x'}, + algo_id='__default__', + ) + updated = self.manager.update_kb( + 'kb_local_clearable', + display_name=None, + description=None, + owner_id=None, + meta=None, + explicit_fields={'display_name', 'description', 'owner_id', 'meta'}, + ) + assert updated['display_name'] is None + assert updated['description'] is None + assert updated['owner_id'] is None + assert updated['meta'] == {} + + def test_kb_update_idempotency_payload_distinguishes_omitted_and_null(self): + keep_req = KbUpdateRequest(display_name='Renamed', idempotency_key='kb-update-idem') + clear_req = KbUpdateRequest(display_name='Renamed', owner_id=None, idempotency_key='kb-update-idem') + + keep_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', keep_req) + clear_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', clear_req) + + assert keep_payload != clear_payload + + self.manager.run_idempotent( + '/v1/kbs/kb_local_idem:patch', + 'kb-update-idem', + keep_payload, + lambda: {'kb_id': 'kb_local_idem', 'owner_id': 'kept'}, + ) + with 
pytest.raises(DocServiceError) as exc: + self.manager.run_idempotent( + '/v1/kbs/kb_local_idem:patch', + 'kb-update-idem', + clear_payload, + lambda: {'kb_id': 'kb_local_idem', 'owner_id': None}, + ) + assert exc.value.biz_code == 'E_IDEMPOTENCY_CONFLICT' + + def test_manager_callback_payload_fallback_and_delete_transition(self): + self.manager.create_kb('kb_local_callback', algo_id='__default__') + file_path = self._make_file('local_callback.txt', 'local callback content') + self.manager._upsert_doc( + doc_id='local-callback-doc', + filename='local_callback.txt', + path=file_path, + metadata={'case': 'callback'}, + source_type=SourceType.EXTERNAL, + ) + self.manager._ensure_kb_document('kb_local_callback', 'local-callback-doc') + queued_at = self.manager._upsert_parse_snapshot( + doc_id='local-callback-doc', + kb_id='kb_local_callback', + algo_id='__default__', + status=DocStatus.DELETING, + task_type=TaskType.DOC_DELETE, + current_task_id='local-delete-task', + queued_at=datetime.now(), + )['queued_at'] + + start_resp = self.manager.on_task_callback(TaskCallbackRequest( + callback_id='local-delete-start', + task_id='local-delete-task', + event_type=CallbackEventType.START, + status=DocStatus.WORKING, + payload={ + 'task_type': TaskType.DOC_DELETE.value, + 'doc_id': 'local-callback-doc', + 'kb_id': 'kb_local_callback', + 'algo_id': '__default__', + }, + )) + assert start_resp['ack'] is True + start_snapshot = self.manager._get_parse_snapshot('local-callback-doc', 'kb_local_callback', '__default__') + assert start_snapshot['status'] == DocStatus.DELETING.value + assert start_snapshot['queued_at'] == queued_at + + finish_resp = self.manager.on_task_callback(TaskCallbackRequest( + callback_id='local-delete-finish', + task_id='local-delete-task', + event_type=CallbackEventType.FINISH, + status=DocStatus.SUCCESS, + payload={ + 'task_type': TaskType.DOC_DELETE.value, + 'doc_id': 'local-callback-doc', + 'kb_id': 'kb_local_callback', + 'algo_id': '__default__', + }, + )) + assert finish_resp['ack'] is True + + finish_snapshot = self.manager._get_parse_snapshot('local-callback-doc', 'kb_local_callback', '__default__') + assert finish_snapshot['status'] == DocStatus.DELETED.value + assert self.manager._has_kb_document('kb_local_callback', 'local-callback-doc') is False + assert self.manager._get_doc('local-callback-doc')['upload_status'] == DocStatus.DELETED.value + + def test_parser_client_algo_endpoint_fallback(self): + client = _ParserClient(parser_url='http://parser.test') + calls = [] + + def fake_get(path, params=None): + del params + calls.append(path) + if path == '/v1/algo/list': + raise RuntimeError('parser http error: 404 missing route') + if path == '/algo/list': + return { + 'code': 200, + 'msg': 'success', + 'data': [{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}], + } + if path == '/v1/algo/__default__/groups': + raise RuntimeError('parser http error: 404 missing route') + if path == '/algo/__default__/group/info': + return { + 'code': 200, + 'msg': 'success', + 'data': [{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}], + } + raise AssertionError(path) + + client._get = fake_get + algo_resp = client.list_algorithms() + group_resp = client.get_algorithm_groups('__default__') + + assert algo_resp.code == 200 + assert group_resp.code == 200 + assert calls == [ + '/v1/algo/list', + '/algo/list', + '/v1/algo/__default__/groups', + '/algo/__default__/group/info', + ] diff --git a/tests/basic_tests/Tools/test_sql_manager.py 
b/tests/basic_tests/Tools/test_sql_manager.py new file mode 100644 index 000000000..fbd4b4680 --- /dev/null +++ b/tests/basic_tests/Tools/test_sql_manager.py @@ -0,0 +1,64 @@ +from unittest.mock import MagicMock, patch + +import sqlalchemy +from sqlalchemy.exc import OperationalError + +from lazyllm.tools.sql.sql_manager import SqlManager + + +def _make_unknown_database_error(): + class _UnknownDatabaseError(Exception): + args = (1049, "Unknown database 'lazyllm_doc_task'") + + return OperationalError('SELECT 1', {}, _UnknownDatabaseError()) + + +def _make_conn_cm(conn): + cm = MagicMock() + cm.__enter__.return_value = conn + cm.__exit__.return_value = False + return cm + + +class TestSqlManager(object): + + @patch('lazyllm.tools.sql.sql_manager.sqlalchemy.create_engine') + def test_tidb_create_database_when_missing(self, mock_create_engine): + probe_engine = MagicMock() + probe_engine.connect.side_effect = _make_unknown_database_error() + + admin_conn = MagicMock() + admin_engine = MagicMock() + admin_engine.connect.return_value = _make_conn_cm(admin_conn) + + final_engine = MagicMock() + mock_create_engine.side_effect = [probe_engine, admin_engine, final_engine] + + sql_manager = SqlManager('tidb', 'root', 'pwd', '127.0.0.1', 4000, 'lazyllm_doc_task') + + assert sql_manager.engine is final_engine + assert mock_create_engine.call_args_list[1].args[0] == 'mysql+pymysql://root:pwd@127.0.0.1:4000/' + assert str(admin_conn.execute.call_args.args[0]) == 'CREATE DATABASE IF NOT EXISTS `lazyllm_doc_task`' + + @patch('lazyllm.tools.sql.sql_manager.sqlalchemy.create_engine') + def test_tidb_skip_create_database_when_exists(self, mock_create_engine): + probe_engine = MagicMock() + probe_engine.connect.return_value = _make_conn_cm(MagicMock()) + + final_engine = MagicMock() + mock_create_engine.side_effect = [probe_engine, final_engine] + + sql_manager = SqlManager('tidb', 'root', 'pwd', '127.0.0.1', 4000, 'lazyllm_doc_task') + + assert sql_manager.engine is final_engine + assert mock_create_engine.call_count == 2 + + def test_tidb_primary_key_string_uses_varchar(self): + sql_manager = SqlManager('tidb', 'root', 'pwd', '127.0.0.1', 4000, 'lazyllm_doc_task') + + primary_key_type = sql_manager._sql_type_for('string', is_primary_key=True) + normal_string_type = sql_manager._sql_type_for('string') + + assert isinstance(primary_key_type, sqlalchemy.String) + assert primary_key_type.length == 255 + assert normal_string_type is sqlalchemy.Text From 56cbcf4885209a4179f5f83d3e486d29cdbee9df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Thu, 12 Mar 2026 09:40:11 +0800 Subject: [PATCH 12/46] fix api error --- lazyllm/tools/rag/parsing_service/base.py | 9 +++-- lazyllm/tools/rag/parsing_service/server.py | 38 +++++++++++++-------- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index 486b67c3f..457d2b73e 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -33,6 +33,9 @@ class DBInfo(BaseModel): options_str: Optional[str] = None +EmptyDBInfo = Annotated[DBInfo | None, BeforeValidator(lambda v: None if v == {} else v)] + + class AddDocRequest(BaseModel): task_id: str = Field(default_factory=lambda: str(uuid4())) algo_id: Optional[str] = '__default__' @@ -40,7 +43,7 @@ class AddDocRequest(BaseModel): file_infos: List[FileInfo] priority: Optional[int] = 0 # NOTE: (db_info, feedback_url) is deprecated, will be removed in the future 
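     # Illustrative sketch of the EmptyDBInfo pattern (assumed behavior, mirrors the
     # EmptyTransfer annotation used for FileInfo): the BeforeValidator coerces an
     # empty-dict payload to None before pydantic validates the nested model, so a
     # client that sends `"db_info": {}` is treated the same as omitting the field:
     #
     #     req = AddDocRequest(file_infos=[], db_info={})
     #     assert req.db_info is None            # coerced by the validator
     #     req = AddDocRequest(file_infos=[])
     #     assert req.db_info is None            # plain default
     #
     # The request payloads above are hypothetical examples, not taken from this patch.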
- db_info: Optional[DBInfo] = None + db_info: EmptyDBInfo = None feedback_url: Optional[str] = None @@ -51,7 +54,7 @@ class UpdateMetaRequest(BaseModel): file_infos: List[FileInfo] priority: Optional[int] = 0 # NOTE: (db_info) is deprecated, will be removed in the future - db_info: Optional[DBInfo] = None + db_info: EmptyDBInfo = None class DeleteDocRequest(BaseModel): @@ -61,7 +64,7 @@ class DeleteDocRequest(BaseModel): doc_ids: List[str] priority: Optional[int] = 0 # NOTE: (db_info) is deprecated, will be removed in the future - db_info: Optional[DBInfo] = None + db_info: EmptyDBInfo = None class CancelTaskRequest(BaseModel): diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index a4fccc684..b438e3414 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -275,26 +275,36 @@ def get_algo_group_info(self, algo_id: str) -> None: if self._shutdown: raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...') try: - algorithm = self._get_algo(algo_id) - if algorithm is None: - raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') - info_pickle_bytes = algorithm.get('info_pickle') - info = load_obj(info_pickle_bytes) - store: _DocumentStore = info['store'] # type: ignore - node_groups = info['node_groups'] - - data = [] - for group_name in store.activated_groups(): - if group_name in node_groups: - group_info = {'name': group_name, 'type': node_groups[group_name].get('group_type'), - 'display_name': node_groups[group_name].get('display_name')} - data.append(group_info) + data = self._get_algo_group_info_data(algo_id) LOG.info(f'[DocumentProcessor] Get group info for {algo_id} success with {data}') return BaseResponse(code=200, msg='success', data=data) + except fastapi.HTTPException: + raise except Exception as e: LOG.error(f'[DocumentProcessor] Failed to get group info: {e}, {traceback.format_exc()}') raise fastapi.HTTPException(status_code=500, detail=f'Failed to get group info: {str(e)}') + @app.get('/group/info') + def get_group_info(self, algo_id: str) -> None: + return self.get_algo_group_info(algo_id) + + def _get_algo_group_info_data(self, algo_id: str) -> List[Dict[str, Any]]: + algorithm = self._get_algo(algo_id) + if algorithm is None: + raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') + info_pickle_bytes = algorithm.get('info_pickle') + info = load_obj(info_pickle_bytes) + store: _DocumentStore = info['store'] # type: ignore + node_groups = info['node_groups'] + + data = [] + for group_name in store.activated_groups(): + if group_name in node_groups: + group_info = {'name': group_name, 'type': node_groups[group_name].get('group_type'), + 'display_name': node_groups[group_name].get('display_name')} + data.append(group_info) + return data + @app.post('/doc/add') def add_doc(self, request: AddDocRequest): # noqa: C901 self._lazy_init() From e4e1e4d054edb3e6efdf29dc24d49897ddda94e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Thu, 12 Mar 2026 10:52:29 +0800 Subject: [PATCH 13/46] fix node transform fallback --- lazyllm/tools/rag/transform/base.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/lazyllm/tools/rag/transform/base.py b/lazyllm/tools/rag/transform/base.py index 36f59730f..67d64b8d9 100644 --- a/lazyllm/tools/rag/transform/base.py +++ b/lazyllm/tools/rag/transform/base.py @@ -128,6 +128,14 @@ def impl(node: DocNode): return sum([impl(node) 
for node in documents], []) def forward(self, nodes: DocNode, **kwargs) -> List[DocNode]: + if type(self).transform is not NodeTransform.transform: + if not getattr(self, '_legacy_transform_compat_warned', False): + LOG.warning( + f'[{type(self).__name__}] `transform()` is deprecated. ' + 'Please implement `forward()` instead.' + ) + self._legacy_transform_compat_warned = True + return self._normalize_splits(self.transform(nodes, **kwargs)) raise NotImplementedError( 'Subclasses must implement forward() to process a single DocNode or RichDocNode' ) @@ -136,6 +144,13 @@ def forward(self, nodes: DocNode, **kwargs) -> List[DocNode]: def transform(self, node: DocNode, **kwargs) -> List[Union[str, DocNode]]: return self.forward(node, **kwargs) + def _normalize_splits(self, splits: Any) -> List[DocNode]: + if splits is None: + return [] + if not isinstance(splits, (list, tuple)): + splits = [splits] + return [s if isinstance(s, DocNode) else DocNode(text=str(s)) for s in splits if s] + def process(self, nodes: List[Any], on_match: Optional[Callable] = None, on_miss: Optional[Callable] = None) -> List[Any]: instance_match = self._on_match From 03a97e41a3df4b12fa3f9dc7e92b8bee650276c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Thu, 12 Mar 2026 14:56:01 +0800 Subject: [PATCH 14/46] update kb api --- lazyllm/tools/rag/doc_service/doc_manager.py | 26 ++++++-------------- lazyllm/tools/rag/parsing_service/server.py | 2 ++ 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index a4562b3bd..1c53f7fb9 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -1369,8 +1369,11 @@ def list_chunks(self, page: int = 1, page_size: int = 20): def health(self): return { 'status': 'ok', - 'version': 'v1-mock', - 'deps': {'sql': True}, + 'version': 'v1', + 'deps': { + 'sql': True, + 'parser': bool(getattr(self._parser_client, '_parser_url', None)), + }, } def list_kbs( @@ -1387,6 +1390,7 @@ def list_kbs( Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) query = session.query(Kb, Rel).outerjoin(Rel, Rel.kb_id == Kb.kb_id) + query = query.filter(Kb.kb_id != '__default__') if keyword: like_expr = f'%{keyword}%' query = query.filter( @@ -1448,25 +1452,11 @@ def create_kb(self, kb_id: str, display_name: Optional[str] = None, description: algo_id: str = '__default__'): if not kb_id: raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + if self._get_kb(kb_id) is not None: + raise DocServiceError('E_STATE_CONFLICT', f'kb already exists: {kb_id}', {'kb_id': kb_id}) self._ensure_algorithm_exists(algo_id) - binding = self._get_kb_algorithm(kb_id) - if binding is not None and binding['algo_id'] != algo_id: - raise DocServiceError( - 'E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm {binding["algo_id"]}', - {'kb_id': kb_id, 'bound_algo_id': binding['algo_id'], 'requested_algo_id': algo_id} - ) - update_fields = set() - if display_name is not None: - update_fields.add('display_name') - if description is not None: - update_fields.add('description') - if owner_id is not None: - update_fields.add('owner_id') - if meta is not None: - update_fields.add('meta') self._ensure_kb( kb_id, display_name=display_name, description=description, owner_id=owner_id, meta=meta, - update_fields=update_fields, ) self._ensure_kb_algorithm(kb_id, 
algo_id) return self.get_kb(kb_id) diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index 4b86f48e3..0c4a69fc9 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -396,6 +396,8 @@ def get_algo_list(self) -> None: 'algo_id': algo_dict.get('id'), 'display_name': algo_dict.get('display_name'), 'description': algo_dict.get('description'), + 'created_at': algo_dict.get('created_at'), + 'updated_at': algo_dict.get('updated_at'), }) if not data: LOG.warning('[DocumentProcessor] No algorithm registered') From aaa3f491a3a531f97762be2a520e05aff645bd2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Thu, 12 Mar 2026 15:49:49 +0800 Subject: [PATCH 15/46] fix error --- lazyllm/tools/rag/doc_node.py | 6 +- lazyllm/tools/rag/parsing_service/impl.py | 74 ++++++++++++++------- lazyllm/tools/rag/parsing_service/server.py | 60 ++++++++++++++--- lazyllm/tools/rag/parsing_service/worker.py | 31 ++++++++- 4 files changed, 132 insertions(+), 39 deletions(-) diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index 3a8d42122..ff6c0f5e2 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -295,10 +295,12 @@ def copy(self, global_metadata: dict = None, metadata: dict = None) -> 'DocNode' node._copy_source = {'uid': self.uid, RAG_KB_ID: self.global_metadata.get(RAG_KB_ID), RAG_DOC_ID: self.global_metadata.get(RAG_DOC_ID)} node._uid = str(uuid.uuid4()) + node._metadata = dict(self._metadata or {}) + node._global_metadata = dict(self._global_metadata or {}) if metadata: - node.metadata.update(metadata) + node._metadata.update(metadata) if global_metadata: - node.global_metadata.update(global_metadata) + node._global_metadata.update(global_metadata) return node def with_score(self, score): diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index 1d089f3b9..cd520663d 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -93,7 +93,8 @@ def reader(self) -> DirectoryReader: def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # noqa: C901 metadatas: Optional[List[Dict[str, Any]]] = None, kb_id: Optional[str] = None, transfer_mode: Optional[str] = None, target_kb_id: Optional[str] = None, - target_doc_ids: Optional[List[str]] = None): + target_doc_ids: Optional[List[str]] = None, + preloaded_root_nodes: Optional[Dict[str, List[DocNode]]] = None): try: if not input_files: return add_start = time.time() @@ -108,7 +109,11 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no load_start = time.time() if transfer_mode is None: - root_nodes = self._reader.load_data(input_files, metadatas, split_nodes_by_type=True) + root_nodes = ( + preloaded_root_nodes + if preloaded_root_nodes is not None + else self._reader.load_data(input_files, metadatas, split_nodes_by_type=True) + ) else: if transfer_mode not in ('cp', 'mv'): raise ValueError(f'Invalid transfer mode: {transfer_mode}') @@ -117,15 +122,18 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no f'doc_ids:{ids}, target_doc_ids:{target_doc_ids}') doc_id_map = {ids[i]: (target_doc_ids[i], metadatas[i]) for i in range(len(ids))} - root_nodes: List[DocNode] = self._store.get_nodes(doc_ids=ids, group=LAZY_ROOT_NAME, kb_id=kb_id) - root_nodes = [ - n.copy( - global_metadata={ - RAG_KB_ID: target_kb_id, RAG_DOC_ID: 
doc_id_map[n.global_metadata[RAG_DOC_ID]][0] - }, - metadata=doc_id_map[n.global_metadata[RAG_DOC_ID]][1] - ) for n in root_nodes - ] + source_root_nodes: List[DocNode] = self._store.get_nodes(doc_ids=ids, group=LAZY_ROOT_NAME, kb_id=kb_id) + root_uid_map = {} + root_nodes = [] + for node in source_root_nodes: + copied = self._clone_node_for_transfer( + node=node, + target_kb_id=target_kb_id, + target_doc_id=doc_id_map[node.global_metadata[RAG_DOC_ID]][0], + metadata=doc_id_map[node.global_metadata[RAG_DOC_ID]][1], + ) + root_uid_map[node.uid] = copied.uid + root_nodes.append(copied) load_time = time.time() - load_start schema_futures = [] @@ -149,7 +157,6 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no self._create_nodes_recursive(v, k) else: self._store.update_nodes(root_nodes, copy=True) - root_uid_map = {n._copy_source.get('uid'): n.uid for n in root_nodes} self._copy_segments_recursive(ids=ids, kb_id=kb_id, target_kb_id=target_kb_id, doc_id_map=doc_id_map, p_uid_map=root_uid_map, p_name=LAZY_ROOT_NAME) @@ -201,6 +208,21 @@ def _create_nodes_recursive(self, p_nodes: List[DocNode], p_name: str): nodes = self._create_nodes_impl(p_nodes, group_name, ref_path=ref_path) if nodes: self._create_nodes_recursive(nodes, group_name) + def _clone_node_for_transfer( + self, node: DocNode, target_kb_id: str, target_doc_id: str, metadata: Dict[str, Any] + ) -> DocNode: + copied = node.copy( + global_metadata={RAG_KB_ID: target_kb_id, RAG_DOC_ID: target_doc_id}, + metadata=metadata, + ) + copied._global_metadata = { + **(node.global_metadata or {}), + RAG_KB_ID: target_kb_id, + RAG_DOC_ID: target_doc_id, + } + copied._metadata = {**(node.metadata or {}), **(metadata or {})} + return copied + def _copy_segments_recursive(self, ids: List[str], kb_id: str, target_kb_id: str, doc_id_map: Dict[str, tuple], p_uid_map: dict, p_name: str): for group_name in self._store.activated_groups(): @@ -209,19 +231,19 @@ def _copy_segments_recursive(self, ids: List[str], kb_id: str, target_kb_id: str raise ValueError(f'Node group {group_name} does not exist. 
Please check the group name ' 'or add a new one through `create_node_group`.') if group['parent'] == p_name: - nodes = self._store.get_nodes(doc_ids=ids, group=group_name, kb_id=kb_id) - nodes = [ - n.copy( - global_metadata={ - RAG_KB_ID: target_kb_id, RAG_DOC_ID: doc_id_map[n.global_metadata[RAG_DOC_ID]][0] - }, - metadata=doc_id_map[n.global_metadata[RAG_DOC_ID]][1] - ) for n in nodes - ] + source_nodes = self._store.get_nodes(doc_ids=ids, group=group_name, kb_id=kb_id) + nodes = [] uid_map = {} - for n in nodes: - uid_map[n._copy_source.get('uid')] = n.uid - n.parent = p_uid_map.get(n.parent, None) if n.parent else None + for source_node in source_nodes: + copied = self._clone_node_for_transfer( + node=source_node, + target_kb_id=target_kb_id, + target_doc_id=doc_id_map[source_node.global_metadata[RAG_DOC_ID]][0], + metadata=doc_id_map[source_node.global_metadata[RAG_DOC_ID]][1], + ) + uid_map[source_node.uid] = copied.uid + copied.parent = p_uid_map.get(source_node.parent, None) if source_node.parent else None + nodes.append(copied) self._store.update_nodes(nodes, copy=True) if nodes: self._copy_segments_recursive(ids=ids, kb_id=kb_id, target_kb_id=target_kb_id, @@ -257,6 +279,7 @@ def _reparse_docs(self, group_name: str, doc_ids: List[str], doc_paths: List[str raise ValueError('metadatas is required for reparse') kb_id = metadatas[0].get(RAG_KB_ID, None) if kb_id is None else kb_id if group_name == 'all': + preloaded_root_nodes = self._reader.load_data(doc_paths, metadatas, split_nodes_by_type=True) self._store.remove_nodes(doc_ids=doc_ids, kb_id=kb_id) removed_flag = False for wait_time in fibonacci_backoff(): @@ -267,7 +290,8 @@ def _reparse_docs(self, group_name: str, doc_ids: List[str], doc_paths: List[str time.sleep(wait_time) if not removed_flag: raise Exception(f'Failed to remove nodes for docs {doc_ids} from store') - self.add_doc(input_files=doc_paths, ids=doc_ids, metadatas=metadatas, kb_id=kb_id) + self.add_doc(input_files=doc_paths, ids=doc_ids, metadatas=metadatas, kb_id=kb_id, + preloaded_root_nodes=preloaded_root_nodes) LOG.info(f'Reparse docs {doc_ids} from store done') else: p_nodes = self._store.get_nodes(group=self._node_groups[group_name]['parent'], diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index b438e3414..f3df99970 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -1,4 +1,5 @@ import json +import inspect import threading import time import traceback @@ -132,6 +133,7 @@ def process_finished_task(self): if finished_task: self._callback( task_id=finished_task.get('task_id'), + task_type=finished_task.get('task_type'), task_status=finished_task.get('task_status'), error_code=finished_task.get('error_code'), error_msg=finished_task.get('error_msg') @@ -496,26 +498,64 @@ def _check_post_func(self) -> bool: if not callable(self._post_func): LOG.error('[DocumentProcessor] Post function is not callable') return False - if not all( - param in self._post_func.__code__.co_varnames for param in [ - 'task_id', 'task_status', 'error_code', 'error_msg' - ] + try: + sig = inspect.signature(self._post_func) + except (TypeError, ValueError): + LOG.error('[DocumentProcessor] Failed to inspect post function signature') + return False + params = sig.parameters + has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + if not has_var_keyword and not all( + param in params for param in ['task_id', 'task_status', 'error_code', 
'error_msg'] ): LOG.error('[DocumentProcessor] Post function params do not include' ' task_id, task_status, error_code, error_msg') return False return True - def _callback(self, task_id: str, task_status: str = None, error_code: str = None, error_msg: str = None): + def _callback(self, task_id: str, task_type: str = None, + task_status: str = None, error_code: str = None, error_msg: str = None): '''callback to service''' - message = f'Task {task_id} finished with status: {task_status}.' - if error_msg: - message += f' Error code: {error_code}, error_msg: {error_msg}.' - LOG.info(f'[DocumentProcessor] {message}') + parts = [f'task_id={task_id}'] + if task_type: + parts.append(f'task_type={task_type}') + if task_status: + parts.append(f'status={task_status}') + has_error = ( + task_status not in (TaskStatus.FINISHED.value, None) + or error_code not in (None, '', '200') + or error_msg not in (None, '', 'success') + ) + if has_error: + if error_code not in (None, ''): + parts.append(f'error_code={error_code}') + if error_msg not in (None, ''): + parts.append(f'error_msg={error_msg}') + message = '[DocumentProcessor] Task callback: ' + ', '.join(parts) + if task_status == TaskStatus.FAILED.value: + LOG.error(message) + elif task_status in (TaskStatus.CANCELED.value, TaskStatus.CANCEL_REQUESTED.value): + LOG.warning(message) + else: + LOG.info(message) if self._post_func: try: - self._post_func(task_id, task_status, error_code, error_msg) + params = inspect.signature(self._post_func).parameters + accepts_task_type = ( + 'task_type' in params + or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) + ) + if accepts_task_type: + self._post_func( + task_id=task_id, + task_type=task_type, + task_status=task_status, + error_code=error_code, + error_msg=error_msg, + ) + else: + self._post_func(task_id, task_status, error_code, error_msg) except Exception as e: LOG.error(f'[DocumentProcessor] Failed to call post function: {e}, {traceback.format_exc()}') raise e diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index 72afee58d..5a3358830 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -317,6 +317,29 @@ def _validate_task_payload(self, task_type: str, payload: dict): if not isinstance(doc_ids, list) or not doc_ids: raise ValueError('doc_ids is required for task_type DOC_DELETE') + def _summarize_task_payload(self, task_type: str, payload: dict) -> str: + summary = { + 'task_type': task_type, + 'algo_id': payload.get('algo_id'), + 'kb_id': payload.get('kb_id'), + } + if task_type == TaskType.DOC_DELETE.value: + summary['doc_ids'] = payload.get('doc_ids', []) + else: + file_infos = [] + for file_info in payload.get('file_infos', []): + transfer_params = file_info.get('transfer_params') or {} + file_infos.append({ + 'doc_id': file_info.get('doc_id'), + 'file_path': file_info.get('file_path'), + 'reparse_group': file_info.get('reparse_group'), + 'target_doc_id': transfer_params.get('target_doc_id'), + 'target_kb_id': transfer_params.get('target_kb_id'), + 'transfer_mode': transfer_params.get('mode'), + }) + summary['file_infos'] = file_infos + return json.dumps(summary, ensure_ascii=False) + def _enqueue_task_from_payload(self, task: dict): try: task_type = task.get('task_type') @@ -377,8 +400,8 @@ def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: boo if not algo_id: raise ValueError(f'{self._log_prefix(task_id)} task_id is missing algo_id in 
payload: {payload}') - LOG.info(f'{self._log_prefix(task_id)} Start processing task, type: {task_type}, ' - f'algo_id: {algo_id}') + LOG.info(f'{self._log_prefix(task_id)} Start processing task: ' + f'{self._summarize_task_payload(task_type, payload)}') processor = self._get_or_create_processor(algo_id) if task_type == TaskType.DOC_ADD.value: @@ -501,6 +524,8 @@ def _worker_impl(self): # noqa: C901 f'{traceback.format_exc()}') time.sleep(WORKER_ERROR_RETRY_INTERVAL) continue + LOG.info(f'{self._log_prefix(task_data["task_id"])} [Worker] Claimed queued task: ' + f'{self._summarize_task_payload(task_data["task_type"], payload)}') self._run_task(task_data['task_id'], task_data['task_type'], payload, from_queue=True) continue @@ -517,6 +542,8 @@ def _worker_impl(self): # noqa: C901 LOG.warning(f'{self._log_prefix()} [Poller] Skip invalid task payload: {e}. ' f'payload={task}') continue + LOG.info(f'{self._log_prefix(task_id)} [Poller] Received direct task: ' + f'{self._summarize_task_payload(task_type, payload)}') self._run_task(task_id, task_type, payload, from_queue=False) except Exception as e: LOG.error(f'{self._log_prefix()} [Poller] fetch failed: {e}') From ee4ba13d83693b39f14bbec3cbe48b6fc2313dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Thu, 12 Mar 2026 19:49:50 +0800 Subject: [PATCH 16/46] fix delete params --- lazyllm/tools/rag/parsing_service/base.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index 457d2b73e..b197da80c 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, BeforeValidator +from pydantic import BaseModel, Field, BeforeValidator, model_validator from typing import Dict, List, Optional, Any, Annotated from enum import Enum from uuid import uuid4 @@ -66,6 +66,14 @@ class DeleteDocRequest(BaseModel): # NOTE: (db_info) is deprecated, will be removed in the future db_info: EmptyDBInfo = None + @model_validator(mode='before') + @classmethod + def _compat_dataset_id(cls, data): + if isinstance(data, dict) and not data.get('kb_id') and data.get('dataset_id'): + data = dict(data) + data['kb_id'] = data['dataset_id'] + return data + class CancelTaskRequest(BaseModel): task_id: str From 3c02f74c44f77290283d5b680dce1c666619d3f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Fri, 13 Mar 2026 19:11:48 +0800 Subject: [PATCH 17/46] modify api --- lazyllm/tools/rag/doc_service/doc_manager.py | 301 +++++++++++++++---- lazyllm/tools/rag/parsing_service/base.py | 30 +- lazyllm/tools/rag/parsing_service/server.py | 85 ++++-- lazyllm/tools/rag/parsing_service/worker.py | 47 ++- 4 files changed, 369 insertions(+), 94 deletions(-) diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index 1c53f7fb9..96fb654ff 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -544,6 +544,12 @@ def _refresh_kb_doc_count(self, kb_id: str): kb_row.updated_at = now_ts() session.add(kb_row) + def _list_kb_doc_ids(self, kb_id: str) -> List[str]: + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + rows = session.query(Rel.doc_id).filter(Rel.kb_id == kb_id).all() + return [row[0] for row in rows] + def _has_kb_document(self, kb_id: str, doc_id: 
str): with self._db_manager.get_session() as session: Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) @@ -635,14 +641,106 @@ def _get_latest_parse_snapshot(self, doc_id: str, kb_id: str): ) return _orm_to_dict(row) if row else None + def _delete_parse_snapshots(self, doc_id: str, kb_id: str): + with self._db_manager.get_session() as session: + State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + session.query(State).filter(State.doc_id == doc_id, State.kb_id == kb_id).delete() + + def _delete_doc_if_orphaned(self, doc_id: str) -> bool: + if self._doc_relation_count(doc_id) > 0: + return False + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if row is None: + return False + session.delete(row) + return True + + def _purge_deleted_kb_doc_data(self, kb_id: str, doc_id: str, remove_relation: bool = False): + if remove_relation: + self._remove_kb_document(kb_id, doc_id) + self._delete_parse_snapshots(doc_id, kb_id) + if not self._delete_doc_if_orphaned(doc_id): + self._sync_doc_upload_status(doc_id) + + def _mark_task_cleanup_policy(self, task_id: str, cleanup_policy: str): + task = self._get_task_record(task_id) + if task is None: + return + message = task.get('message') or {} + if message.get('cleanup_policy') == cleanup_policy: + return + message['cleanup_policy'] = cleanup_policy + self._update_task_record(task_id, message=_to_json(message)) + + def _finalize_kb_deletion_if_empty(self, kb_id: str) -> bool: + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + if session.query(Rel).filter(Rel.kb_id == kb_id).count() > 0: + return False + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + AlgoRel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if kb_row is not None: + session.delete(kb_row) + session.query(AlgoRel).filter(AlgoRel.kb_id == kb_id).delete() + return True + + def _prepare_kb_delete_items(self, kb_id: str) -> Dict[str, Any]: + kb = self._get_kb(kb_id) + if kb is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + binding = self._get_kb_algorithm(kb_id) + default_algo_id = binding['algo_id'] if binding is not None else '__default__' + items = [] + for doc_id in self._list_kb_doc_ids(kb_id): + snapshot = self._get_parse_snapshot(doc_id, kb_id, default_algo_id) or self._get_latest_parse_snapshot(doc_id, kb_id) + if snapshot is None or snapshot.get('status') == DocStatus.DELETED.value: + items.append({'action': 'purge_local', 'doc_id': doc_id}) + continue + status = snapshot.get('status') + task_type = snapshot.get('task_type') + if status in (DocStatus.WAITING.value, DocStatus.WORKING.value): + if task_type == TaskType.DOC_DELETE.value and snapshot.get('current_task_id'): + items.append({ + 'action': 'reuse_delete_task', + 'doc_id': doc_id, + 'task_id': snapshot['current_task_id'], + }) + continue + raise DocServiceError( + 'E_STATE_CONFLICT', + f'cannot delete kb while doc {doc_id} task is {status}', + {'kb_id': kb_id, 'doc_id': doc_id, 'status': status, 'task_type': task_type}, + ) + items.append({ + 'action': 'enqueue_delete', + 'doc_id': doc_id, + 'algo_id': snapshot.get('algo_id') or default_algo_id, + }) + return {'kb': kb, 'items': items} + def 
_assert_action_allowed(self, doc_id: str, kb_id: str, algo_id: str, action: str): snapshot = self._get_parse_snapshot(doc_id, kb_id, algo_id) - if snapshot is None: + status = snapshot.get('status') if snapshot is not None else None + if status is None and action in ('add', 'upload'): + doc = self._get_doc(doc_id) + status = doc.get('upload_status') if doc else None + + if action in ('add', 'upload'): + if status in ( + DocStatus.WAITING.value, + DocStatus.WORKING.value, + DocStatus.DELETING.value, + DocStatus.SUCCESS.value, + ): + raise DocServiceError('E_STATE_CONFLICT', f'cannot {action} while state is {status}') return - status = snapshot.get('status') - if status == DocStatus.WORKING.value and action in ('upload', 'reparse', 'delete', 'transfer', 'metadata'): + + if status == DocStatus.WORKING.value and action in ('reparse', 'delete', 'transfer', 'metadata'): raise DocServiceError('E_STATE_CONFLICT', f'cannot {action} while state is WORKING') - if status == DocStatus.DELETING.value and action in ('upload', 'reparse', 'delete', 'transfer', 'metadata'): + if status == DocStatus.DELETING.value and action in ('reparse', 'delete', 'transfer', 'metadata'): raise DocServiceError('E_STATE_CONFLICT', f'cannot {action} while state is DELETING') def _upsert_parse_snapshot( @@ -740,10 +838,14 @@ def _call_parser_client(self, method, *args, **kwargs): try: return method(*args, **kwargs) except TypeError as exc: - if 'callback_url' not in kwargs or 'callback_url' not in str(exc): - raise compat_kwargs = dict(kwargs) - compat_kwargs.pop('callback_url', None) + removed = False + for field in ('callback_url',): + if field in compat_kwargs and field in str(exc): + compat_kwargs.pop(field, None) + removed = True + if not removed: + raise return method(*args, **compat_kwargs) def _create_parser_task(self, task_id: str, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, @@ -784,7 +886,7 @@ def _enqueue_task( self, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, idempotency_key: Optional[str] = None, priority: int = 0, file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, - reparse_group: Optional[str] = None, + reparse_group: Optional[str] = None, cleanup_policy: Optional[str] = None, ): task_id = str(uuid4()) task_message = { @@ -795,6 +897,8 @@ def _enqueue_task( 'metadata': metadata, 'reparse_group': reparse_group, } + if cleanup_policy: + task_message['cleanup_policy'] = cleanup_policy task_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING self._create_task_record(task_id, task_type, doc_id, kb_id, algo_id, task_status, message=task_message) parse_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING @@ -886,6 +990,59 @@ def _prepare_upload_items(self, request: UploadRequest) -> List[Dict[str, Any]]: self._assert_action_allowed(item['doc_id'], request.kb_id, request.algo_id, 'upload') return prepared_items + def _prepare_reparse_items(self, request: ReparseRequest) -> List[Dict[str, Any]]: + prepared_items = [] + for doc_id in request.doc_ids: + doc = self._get_doc(doc_id) + if doc is None or not self._has_kb_document(request.kb_id, doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') + self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'reparse') + prepared_items.append({ + 'doc_id': doc_id, + 'file_path': doc.get('path'), + 'metadata': _from_json(doc.get('meta')), + }) + return prepared_items + + def _prepare_delete_items(self, request: 
DeleteRequest) -> List[Dict[str, Any]]: + prepared_items = [] + for doc_id in request.doc_ids: + doc = self._get_doc(doc_id) + snapshot = self._get_parse_snapshot(doc_id, request.kb_id, request.algo_id) or self._get_latest_parse_snapshot(doc_id, request.kb_id) + if doc is None or not self._has_kb_document(request.kb_id, doc_id): + if snapshot is not None and snapshot.get('status') in (DocStatus.DELETING.value, DocStatus.DELETED.value): + prepared_items.append({ + 'doc_id': doc_id, + 'action': 'noop', + 'status': snapshot['status'], + 'task_id': snapshot.get('current_task_id'), + }) + continue + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') + if snapshot is not None and snapshot.get('status') in (DocStatus.DELETING.value, DocStatus.DELETED.value): + prepared_items.append({ + 'doc_id': doc_id, + 'action': 'noop', + 'status': snapshot['status'], + 'task_id': snapshot.get('current_task_id'), + }) + continue + self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'delete') + prepared_items.append({'doc_id': doc_id, 'action': 'execute'}) + return prepared_items + + def _prepare_metadata_patch_items(self, request: MetadataPatchRequest) -> List[Dict[str, Any]]: + prepared_items = [] + for item in request.items: + doc = self._get_doc(item.doc_id) + if doc is None or not self._has_kb_document(request.kb_id, item.doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') + self._assert_action_allowed(item.doc_id, request.kb_id, request.algo_id, 'metadata') + merged = _from_json(doc.get('meta')) + merged.update(item.patch) + prepared_items.append({'doc_id': item.doc_id, 'metadata': merged, 'file_path': doc.get('path')}) + return prepared_items + def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: self._validate_kb_algorithm(request.kb_id, request.algo_id) prepared_items = self._prepare_upload_items(request) @@ -945,17 +1102,14 @@ def add_files(self, request: AddRequest) -> List[Dict[str, Any]]: def reparse(self, request: ReparseRequest) -> List[str]: self._validate_kb_algorithm(request.kb_id, request.algo_id) self._validate_unique_doc_ids(request.doc_ids, field_name='doc_id') + prepared_items = self._prepare_reparse_items(request) task_ids = [] - for doc_id in request.doc_ids: - doc = self._get_doc(doc_id) - if doc is None or not self._has_kb_document(request.kb_id, doc_id): - raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') - self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'reparse') + for item in prepared_items: task_id, _ = self._enqueue_task( - doc_id, request.kb_id, request.algo_id, TaskType.DOC_REPARSE, + item['doc_id'], request.kb_id, request.algo_id, TaskType.DOC_REPARSE, idempotency_key=request.idempotency_key, - file_path=doc.get('path'), - metadata=_from_json(doc.get('meta')), + file_path=item['file_path'], + metadata=item['metadata'], reparse_group='all', ) task_ids.append(task_id) @@ -964,12 +1118,41 @@ def reparse(self, request: ReparseRequest) -> List[str]: def delete(self, request: DeleteRequest) -> List[Dict[str, Any]]: self._validate_kb_algorithm(request.kb_id, request.algo_id) self._validate_unique_doc_ids(request.doc_ids, field_name='doc_id') + prepared_items = self._prepare_delete_items(request) items: List[Dict[str, Any]] = [] - for doc_id in request.doc_ids: - doc = self._get_doc(doc_id) - if doc is None or not self._has_kb_document(request.kb_id, doc_id): - raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') - 
self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'delete') + for item in prepared_items: + doc_id = item['doc_id'] + if item.get('action') == 'noop': + items.append({ + 'doc_id': doc_id, + 'accepted': True, + 'task_id': item.get('task_id'), + 'status': item['status'], + 'error_code': None, + }) + continue + snapshot = self._get_parse_snapshot(doc_id, request.kb_id, request.algo_id) + if ( + snapshot is not None + and snapshot.get('status') == DocStatus.WAITING.value + and snapshot.get('task_type') == TaskType.DOC_ADD.value + and snapshot.get('current_task_id') + ): + cancel_resp = self.cancel_task(snapshot['current_task_id']) + if cancel_resp.code != 200: + raise DocServiceError( + 'E_STATE_CONFLICT', + cancel_resp.msg, + cancel_resp.data if isinstance(cancel_resp.data, dict) else {'task_id': snapshot['current_task_id']}, + ) + items.append({ + 'doc_id': doc_id, + 'accepted': True, + 'task_id': snapshot['current_task_id'], + 'status': DocStatus.CANCELED.value, + 'error_code': None, + }) + continue with self._db_manager.get_session() as session: Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) row = session.query(Doc).filter(Doc.doc_id == doc_id).first() @@ -1030,27 +1213,22 @@ def patch_metadata(self, request: MetadataPatchRequest): updated = [] failed = [] self._validate_unique_doc_ids([item.doc_id for item in request.items], field_name='doc_id') - for item in request.items: - doc = self._get_doc(item.doc_id) - if doc is None or not self._has_kb_document(request.kb_id, item.doc_id): - raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') - self._assert_action_allowed(item.doc_id, request.kb_id, request.algo_id, 'metadata') - merged = _from_json(doc.get('meta')) - merged.update(item.patch) + prepared_items = self._prepare_metadata_patch_items(request) + for item in prepared_items: with self._db_manager.get_session() as session: Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) - row = session.query(Doc).filter(Doc.doc_id == item.doc_id).first() - row.meta = _to_json(merged) + row = session.query(Doc).filter(Doc.doc_id == item['doc_id']).first() + row.meta = _to_json(item['metadata']) row.updated_at = now_ts() session.add(row) task_id, _ = self._enqueue_task( - item.doc_id, request.kb_id, request.algo_id, TaskType.DOC_UPDATE_META, + item['doc_id'], request.kb_id, request.algo_id, TaskType.DOC_UPDATE_META, idempotency_key=request.idempotency_key, - file_path=doc.get('path'), - metadata=merged, + file_path=item['file_path'], + metadata=item['metadata'], ) - updated.append({'doc_id': item.doc_id, 'task_id': task_id}) + updated.append({'doc_id': item['doc_id'], 'task_id': task_id}) return { 'updated_count': len(updated), 'doc_ids': [u['doc_id'] for u in updated], @@ -1114,6 +1292,8 @@ def on_task_callback(self, callback: TaskCallbackRequest): snapshot = self._get_parse_snapshot(doc_id, kb_id, algo_id) if snapshot and snapshot.get('current_task_id') and snapshot['current_task_id'] != callback.task_id: return {'ack': True, 'deduped': False, 'ignored_reason': 'stale_task_callback'} + if task_data.get('status') == DocStatus.CANCELED.value and callback.status != DocStatus.CANCELED: + return {'ack': True, 'deduped': False, 'ignored_reason': 'canceled_task_callback'} if callback.event_type == CallbackEventType.START: self._update_task_record( @@ -1153,6 +1333,8 @@ def on_task_callback(self, callback: TaskCallbackRequest): failed_stage = None if final_status == DocStatus.FAILED: failed_stage = 'DELETE' if task_type 
== TaskType.DOC_DELETE else 'PARSE' + task_message = task_data.get('message') or {} + cleanup_policy = task_message.get('cleanup_policy') self._update_task_record( callback.task_id, @@ -1181,6 +1363,9 @@ def on_task_callback(self, callback: TaskCallbackRequest): if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED: self._remove_kb_document(kb_id, doc_id) self._apply_doc_upload_status(doc_id, task_type, final_status) + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED and cleanup_policy == 'purge': + self._purge_deleted_kb_doc_data(kb_id, doc_id) + self._finalize_kb_deletion_if_empty(kb_id) return {'ack': True, 'deduped': False, 'ignored_reason': None} @@ -1492,30 +1677,34 @@ def update_kb(self, kb_id: str, display_name: Optional[str] = None, description: def delete_kb(self, kb_id: str): if not kb_id: raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') - with self._db_manager.get_session() as session: - Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) - Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) - Snap = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) - kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() - if kb_row is None: - raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) - states = ( - session.query(Snap) - .join(Rel, sqlalchemy.and_(Snap.doc_id == Rel.doc_id, Snap.kb_id == Rel.kb_id)) - .filter(Rel.kb_id == kb_id, ~Snap.status.in_([DocStatus.DELETED.value, DocStatus.CANCELED.value])) - .all() - ) + prepared = self._prepare_kb_delete_items(kb_id) task_ids = [] - for row in states: - task_id, _ = self._enqueue_task(row.doc_id, row.kb_id, row.algo_id, TaskType.DOC_DELETE) - task_ids.append(task_id) + for item in prepared['items']: + if item['action'] == 'purge_local': + self._purge_deleted_kb_doc_data(kb_id, item['doc_id'], remove_relation=True) + continue + if item['action'] == 'reuse_delete_task': + self._mark_task_cleanup_policy(item['task_id'], 'purge') + task_ids.append(item['task_id']) + continue + if item['action'] == 'enqueue_delete': + task_id, _ = self._enqueue_task( + item['doc_id'], kb_id, item['algo_id'], TaskType.DOC_DELETE, cleanup_policy='purge' + ) + task_ids.append(task_id) + continue + raise RuntimeError(f'unsupported kb delete action: {item["action"]}') new_status = KBStatus.DELETING.value if task_ids else KBStatus.DELETED.value - with self._db_manager.get_session() as session: - Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) - kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() - kb_row.status = new_status - kb_row.updated_at = now_ts() - session.add(kb_row) + if task_ids: + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if kb_row is not None: + kb_row.status = new_status + kb_row.updated_at = now_ts() + session.add(kb_row) + else: + self._finalize_kb_deletion_if_empty(kb_id) return {'kb_id': kb_id, 'status': new_status, 'task_ids': task_ids} def delete_kbs(self, kb_ids: List[str]): diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index f7ff8390d..55641a150 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from typing import Dict, 
List, Optional, Any from enum import Enum from uuid import uuid4 @@ -34,6 +34,14 @@ class AddDocRequest(BaseModel): db_info: Optional[DBInfo] = None feedback_url: Optional[str] = None + @model_validator(mode='before') + @classmethod + def normalize_deprecated_fields(cls, data): + if isinstance(data, dict) and not data.get('db_info'): + data = dict(data) + data['db_info'] = None + return data + class UpdateMetaRequest(BaseModel): task_id: str = Field(default_factory=lambda: str(uuid4())) @@ -46,6 +54,14 @@ class UpdateMetaRequest(BaseModel): db_info: Optional[DBInfo] = None feedback_url: Optional[str] = None + @model_validator(mode='before') + @classmethod + def normalize_deprecated_fields(cls, data): + if isinstance(data, dict) and not data.get('db_info'): + data = dict(data) + data['db_info'] = None + return data + class DeleteDocRequest(BaseModel): task_id: str = Field(default_factory=lambda: str(uuid4())) @@ -58,6 +74,18 @@ class DeleteDocRequest(BaseModel): db_info: Optional[DBInfo] = None feedback_url: Optional[str] = None + @model_validator(mode='before') + @classmethod + def normalize_legacy_fields(cls, data): + if not isinstance(data, dict): + return data + data = dict(data) + if not data.get('kb_id') and data.get('dataset_id'): + data['kb_id'] = data['dataset_id'] + if not data.get('db_info'): + data['db_info'] = None + return data + class CancelTaskRequest(BaseModel): task_id: str diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index 0c4a69fc9..fbe70cbb2 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -409,26 +409,55 @@ def get_algo_group_info(self, algo_id: str) -> None: if self._shutdown: raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...') try: - algorithm = self._get_algo(algo_id) - if algorithm is None: - raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') - info_pickle_bytes = algorithm.get('info_pickle') - info = cloudpickle.loads(info_pickle_bytes) - store: _DocumentStore = info['store'] # type: ignore - node_groups = info['node_groups'] - - data = [] - for group_name in store.activated_groups(): - if group_name in node_groups: - group_info = {'name': group_name, 'type': node_groups[group_name].get('group_type'), - 'display_name': node_groups[group_name].get('display_name')} - data.append(group_info) + data = self._get_algo_group_info_data(algo_id) LOG.info(f'[DocumentProcessor] Get group info for {algo_id} success with {data}') return BaseResponse(code=200, msg='success', data=data) + except fastapi.HTTPException: + raise except Exception as e: LOG.error(f'[DocumentProcessor] Failed to get group info: {e}, {traceback.format_exc()}') raise fastapi.HTTPException(status_code=500, detail=f'Failed to get group info: {str(e)}') + def get_group_info(self, algo_id: str) -> None: + return self.get_algo_group_info(algo_id) + + def _get_algo_group_info_data(self, algo_id: str): + algorithm = self._get_algo(algo_id) + if algorithm is None: + raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') + info_pickle_bytes = algorithm.get('info_pickle') + info = cloudpickle.loads(info_pickle_bytes) + store: _DocumentStore = info['store'] # type: ignore + node_groups = info['node_groups'] + + data = [] + for group_name in store.activated_groups(): + if group_name in node_groups: + group_info = {'name': group_name, 'type': node_groups[group_name].get('group_type'), + 'display_name': 
node_groups[group_name].get('display_name')} + data.append(group_info) + return data + + @staticmethod + def _resolve_add_task_type(file_infos) -> str: + has_reparse = False + has_new_file = False + for file_info in file_infos: + if file_info.reparse_group is not None: + has_reparse = True + else: + has_new_file = True + if has_new_file and has_reparse: + raise fastapi.HTTPException( + status_code=400, + detail='new_file_ids and reparse_file_ids cannot be specified at the same time' + ) + if has_reparse: + return TaskType.DOC_REPARSE.value + if has_new_file: + return TaskType.DOC_ADD.value + raise fastapi.HTTPException(status_code=400, detail='no input files or reparse group specified') + @app.post('/doc/add') def add_doc(self, request: AddDocRequest): self._lazy_init() @@ -445,26 +474,10 @@ def add_doc(self, request: AddDocRequest): if algorithm is None: raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') # NOTE: No idempotency key check, should be handled by the caller! - new_file_ids = [] - reparse_file_ids = [] for file_info in file_infos: if self._path_prefix: file_info.file_path = create_file_path(path=file_info.file_path, prefix=self._path_prefix) - if file_info.reparse_group is not None: - reparse_file_ids.append(file_info.doc_id) - else: - new_file_ids.append(file_info.doc_id) - if new_file_ids and reparse_file_ids: - raise fastapi.HTTPException( - status_code=400, - detail='new_file_ids and reparse_file_ids cannot be specified at the same time' - ) - if new_file_ids: - task_type = TaskType.DOC_ADD.value - elif reparse_file_ids: - task_type = TaskType.DOC_REPARSE.value - else: - raise fastapi.HTTPException(status_code=400, detail='no input files or reparse group specified') + task_type = self._resolve_add_task_type(file_infos) payload = request.model_dump() resolved_callback_url = self._resolve_callback_url(payload) if resolved_callback_url: @@ -638,9 +651,12 @@ def _check_post_func(self) -> bool: return False return True - def _callback(self, finished_task: Dict[str, Any]): + def _callback(self, finished_task: Optional[Dict[str, Any]] = None, **legacy_kwargs): '''callback to service''' + if finished_task is None: + finished_task = legacy_kwargs task_id = finished_task.get('task_id') + task_type = finished_task.get('task_type') task_status = finished_task.get('task_status') error_code = finished_task.get('error_code') error_msg = finished_task.get('error_msg') @@ -651,7 +667,10 @@ def _callback(self, finished_task: Dict[str, Any]): try: if self._post_func: - self._post_func(task_id, task_status, error_code, error_msg) + if 'task_type' in self._post_func.__code__.co_varnames: + self._post_func(task_id, task_status, error_code, error_msg, task_type=task_type) + else: + self._post_func(task_id, task_status, error_code, error_msg) else: self._default_post_func(finished_task) except Exception as e: diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index 7ef7c7550..76c3b6ae9 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -9,6 +9,7 @@ from ..utils import BaseResponse, _get_default_db_config from .base import ( FINISHED_TASK_QUEUE_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, + AddDocRequest, UpdateMetaRequest, DeleteDocRequest, TaskStatus, TaskType, ALGORITHM_TABLE_INFO ) from .impl import _Processor @@ -196,6 +197,40 @@ def _build_task_context(task_type: str, payload: dict) -> dict: 'items': items, } + @staticmethod + def 
_infer_task_type(task_type: str, payload: dict) -> str: + if task_type == TaskType.DOC_DELETE.value: + return task_type + if task_type == TaskType.DOC_UPDATE_META.value: + return task_type + + file_infos = payload.get('file_infos') or [] + has_reparse = any(file_info.get('reparse_group') is not None for file_info in file_infos) + if has_reparse: + return TaskType.DOC_REPARSE.value + return task_type + + def _parse_task_payload(self, task_data: dict): + task_id = task_data.get('task_id') + task_type = task_data.get('task_type') + if not task_id: + raise ValueError('task_id is required') + if not task_type: + raise ValueError('task_type is required') + + if task_type == TaskType.DOC_DELETE.value: + payload = DeleteDocRequest.model_validate(task_data).model_dump(mode='json') + elif task_type == TaskType.DOC_UPDATE_META.value: + payload = UpdateMetaRequest.model_validate(task_data).model_dump(mode='json') + else: + payload = AddDocRequest.model_validate(task_data).model_dump(mode='json') + + task_type = self._infer_task_type(task_type, payload) + return task_id, task_type, payload + + def _summarize_task_payload(self, task_type: str, payload: dict) -> str: + return json.dumps(self._build_task_context(task_type, payload), ensure_ascii=False, sort_keys=True) + def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: TaskStatus, error_code: str = None, error_msg: str = None, callback_url: str = None, task_context_json: str = None): @@ -231,18 +266,22 @@ def _worker_impl(self): time.sleep(0.1) continue - task_id = task_data['task_id'] - task_type = task_data['task_type'] - payload = json.loads(task_data.get('message')) + raw_payload = json.loads(task_data.get('message')) + task_id, task_type, payload = self._parse_task_payload({ + 'task_id': task_data.get('task_id'), + 'task_type': task_data.get('task_type'), + **raw_payload, + }) algo_id = payload.get('algo_id') if not algo_id: raise ValueError(f'[DocumentProcessorWorker._Impl] task_id {task_id} is missing algo_id in ' f'payload: {payload}') callback_url = self._resolve_callback_url(payload) task_context_json = json.dumps(self._build_task_context(task_type, payload), ensure_ascii=False) + task_summary = self._summarize_task_payload(task_type, payload) LOG.info(f'[DocumentProcessorWorker._Impl] Start processing task {task_id}, type: {task_type},' - f' algo_id: {algo_id}') + f' algo_id: {algo_id}, payload={task_summary}') self._enqueue_finished_task( task_id=task_id, task_type=task_type, From c81dafdb0b768f1fad0a63629cd21c97bac512c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 16 Mar 2026 10:55:04 +0800 Subject: [PATCH 18/46] temp doc server example --- examples/rag/doc_service_standalone.py | 219 +------------------------ 1 file changed, 1 insertion(+), 218 deletions(-) diff --git a/examples/rag/doc_service_standalone.py b/examples/rag/doc_service_standalone.py index 99e1e38e0..1c3e58095 100644 --- a/examples/rag/doc_service_standalone.py +++ b/examples/rag/doc_service_standalone.py @@ -14,24 +14,17 @@ from __future__ import annotations import argparse -import json import os import tempfile import threading import time -from datetime import datetime -from typing import Any, Dict, Optional -from uuid import uuid4 +from typing import Any, Dict import requests -import lazyllm from lazyllm import Document from lazyllm.tools.rag.doc_service import DocServer -from lazyllm.tools.rag.doc_service.base import CallbackEventType, DocStatus, TaskCallbackRequest, TaskCreateRequest from 
lazyllm.tools.rag.parsing_service import DocumentProcessor -from lazyllm.tools.rag.parsing_service.base import TaskStatus, TaskType -from lazyllm.tools.rag.utils import BaseResponse REAL_ALGO_ID = 'real-standalone-algo' @@ -71,203 +64,6 @@ def _poll(): return _wait_until(_poll, timeout=timeout) -def _task_status_to_doc_status(task_status: str) -> DocStatus: - mapping = { - TaskStatus.SUCCESS.value: DocStatus.SUCCESS, - TaskStatus.FAILED.value: DocStatus.FAILED, - TaskStatus.CANCELED.value: DocStatus.CANCELED, - } - if task_status not in mapping: - raise RuntimeError(f'unsupported task status: {task_status}') - return mapping[task_status] - - -class _RealProcessorTaskAdapter: - def __init__(self, parser_base_url: str, manager, upstream_client): - self._parser_base_url = parser_base_url.rstrip('/') - self._manager = manager - self._upstream_client = upstream_client - self._tasks: Dict[str, Dict[str, Any]] = {} - self._lock = threading.Lock() - - def _load_doc(self, doc_id: str) -> Dict[str, Any]: - doc = self._manager._get_doc(doc_id) - if doc is None: - raise RuntimeError(f'doc not found: {doc_id}') - return doc - - @staticmethod - def _load_metadata(doc: Dict[str, Any]) -> Dict[str, Any]: - raw = doc.get('meta') - return json.loads(raw) if raw else {} - - def _record_task(self, req: TaskCreateRequest) -> Dict[str, Any]: - now = datetime.now().isoformat() - task = { - 'task_id': req.task_id, - 'task_type': req.task_type.value, - 'doc_id': req.doc_id, - 'kb_id': req.kb_id, - 'algo_id': req.algo_id, - 'status': TaskStatus.WAITING.value, - 'priority': req.priority, - 'callback_url': req.callback_url, - 'error_code': None, - 'error_msg': None, - 'created_at': now, - 'updated_at': now, - 'started_at': None, - 'finished_at': None, - } - with self._lock: - self._tasks[req.task_id] = task - return task - - def _dispatch_add_like_task(self, req: TaskCreateRequest, doc: Dict[str, Any], reparse: bool = False): - file_info = { - 'file_path': doc['path'], - 'doc_id': req.doc_id, - 'metadata': self._load_metadata(doc), - } - if reparse: - file_info['reparse_group'] = 'CoarseChunk' - payload = { - 'task_id': req.task_id, - 'algo_id': req.algo_id, - 'kb_id': req.kb_id, - 'file_infos': [file_info], - 'priority': req.priority, - } - return requests.post(f'{self._parser_base_url}/doc/add', json=payload, timeout=15) - - def create_task(self, req: TaskCreateRequest): - self._record_task(req) - try: - if req.task_type == TaskType.DOC_ADD: - resp = self._dispatch_add_like_task(req, self._load_doc(req.doc_id)) - elif req.task_type == TaskType.DOC_REPARSE: - resp = self._dispatch_add_like_task(req, self._load_doc(req.doc_id), reparse=True) - elif req.task_type == TaskType.DOC_DELETE: - resp = requests.delete( - f'{self._parser_base_url}/doc/delete', - json={ - 'task_id': req.task_id, - 'algo_id': req.algo_id, - 'kb_id': req.kb_id, - 'doc_ids': [req.doc_id], - 'priority': req.priority, - }, - timeout=15, - ) - elif req.task_type == TaskType.DOC_UPDATE_META: - doc = self._load_doc(req.doc_id) - resp = requests.post( - f'{self._parser_base_url}/doc/meta/update', - json={ - 'task_id': req.task_id, - 'algo_id': req.algo_id, - 'kb_id': req.kb_id, - 'file_infos': [{ - 'file_path': doc['path'], - 'doc_id': req.doc_id, - 'metadata': req.metadata, - }], - 'priority': req.priority, - }, - timeout=15, - ) - else: - raise RuntimeError(f'unsupported task type: {req.task_type.value}') - if resp.status_code >= 400: - raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') - result = 
BaseResponse.model_validate(resp.json()) - if result.code != 200: - raise RuntimeError(f'parser task rejected: {result.msg}') - return result - except Exception: - with self._lock: - self._tasks.pop(req.task_id, None) - raise - - def mark_task_finished(self, task_id: str, task_status: str, - error_code: Optional[str] = None, error_msg: Optional[str] = None): - with self._lock: - task = self._tasks.get(task_id) - if task is None: - return None - finished_at = datetime.now().isoformat() - task['status'] = task_status - task['error_code'] = error_code - task['error_msg'] = error_msg - task['finished_at'] = finished_at - task['updated_at'] = finished_at - return dict(task) - - def cancel_task(self, task_id: str): - resp = requests.post(f'{self._parser_base_url}/doc/cancel', json={'task_id': task_id}, timeout=8) - if resp.status_code >= 400: - raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') - result = BaseResponse.model_validate(resp.json()) - if result.code == 200 and result.data and result.data.get('cancel_status'): - self.mark_task_finished(task_id, TaskStatus.CANCELED.value) - return result - - def list_tasks(self, status: Optional[list[str]], page: int, page_size: int): - with self._lock: - items = [dict(task) for task in self._tasks.values()] - if status: - items = [item for item in items if item['status'] in status] - items.sort(key=lambda item: item['created_at'], reverse=True) - total = len(items) - sliced = items[(page - 1) * page_size:page * page_size] - return BaseResponse( - code=200, - msg='success', - data={'items': sliced, 'total': total, 'page': page, 'page_size': page_size}, - ) - - def get_task(self, task_id: str): - with self._lock: - task = self._tasks.get(task_id) - if task is None: - return BaseResponse(code=404, msg='task not found', data=None) - return BaseResponse(code=200, msg='success', data=dict(task)) - - def list_algorithms(self): - return self._upstream_client.list_algorithms() - - def get_algorithm_groups(self, algo_id: str): - return self._upstream_client.get_algorithm_groups(algo_id) - - -def _make_post_func(state: Dict[str, Any]): - def _post_func(task_id: str, task_status: str, error_code: str = None, error_msg: str = None): - adapter = state['adapter'] - callback_url = state['callback_url'] - task = adapter.mark_task_finished(task_id, task_status, error_code, error_msg) - if task is None: - raise RuntimeError(f'untracked callback task: {task_id}') - callback = TaskCallbackRequest( - callback_id=str(uuid4()), - task_id=task_id, - event_type=CallbackEventType.FINISH, - status=_task_status_to_doc_status(task_status), - error_code=error_code, - error_msg=error_msg, - payload={ - 'task_type': task['task_type'], - 'doc_id': task['doc_id'], - 'kb_id': task['kb_id'], - 'algo_id': task['algo_id'], - }, - ) - resp = requests.post(callback_url, json=callback.model_dump(mode='json'), timeout=8) - resp.raise_for_status() - return True - - return _post_func - - def _build_store_conf(root_dir: str) -> Dict[str, Any]: segment_store_path = os.path.join(root_dir, 'segments.db') milvus_store_path = os.path.join(root_dir, 'milvus_lite.db') @@ -297,13 +93,11 @@ def _start_full_stack(args): os.makedirs(storage_dir, exist_ok=True) parser_db = os.path.join(tmp_dir, 'parser.db') doc_db = os.path.join(tmp_dir, 'doc_service.db') - callback_state: Dict[str, Any] = {} parser = DocumentProcessor( port=args.parser_port, db_config=_make_db_config(parser_db), num_workers=args.num_workers, - post_func=_make_post_func(callback_state), ) parser.start() parser_base_url 
= parser._impl._url.rsplit('/', 1)[0] @@ -346,17 +140,6 @@ def _start_full_stack(args): base_url = server.url.rsplit('/', 1)[0] _wait_http_ok(f'{base_url}/v1/health') - raw_impl = server._raw_impl - raw_impl._lazy_init() - adapter = _RealProcessorTaskAdapter( - parser_base_url=parser_base_url, - manager=raw_impl._manager, - upstream_client=raw_impl._manager._parser_client, - ) - raw_impl._manager._parser_client = adapter - callback_state['adapter'] = adapter - callback_state['callback_url'] = raw_impl._manager._callback_url - print(f'DocService URL: {base_url}', flush=True) print(f'DocService Docs: {base_url}/docs', flush=True) print(f'Parser URL: {parser_base_url}', flush=True) From 9ca28cbd772f5c429d70d8163d5d4a84d7271681 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 16 Mar 2026 11:03:20 +0800 Subject: [PATCH 19/46] lint --- lazyllm/tools/rag/parsing_service/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index 5a3358830..2404cc0ec 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -391,7 +391,7 @@ def _parse_task_payload(self, task: dict): self._validate_task_payload(task_type, payload) return task_id, task_type, payload - def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: bool): + def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: bool): # noqa: C901 try: self._in_progress_task = {'task_id': task_id, 'task_type': task_type} if from_queue: From 5bb262503863b3bb629317017eba4026d59eb968 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 16 Mar 2026 14:39:57 +0800 Subject: [PATCH 20/46] tmp code --- lazyllm/tools/rag/doc_service/doc_server.py | 42 +++++++++++++ lazyllm/tools/sql/sql_manager.py | 70 ++++++++++++++++----- 2 files changed, 96 insertions(+), 16 deletions(-) diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py index 827a24957..056e34f06 100644 --- a/lazyllm/tools/rag/doc_service/doc_server.py +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -604,6 +604,48 @@ def __init__( ) self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) + @staticmethod + def _register_openapi_routes(openapi_app: 'fastapi.FastAPI', impl: 'DocServer._Impl'): + def _find_services(cls): + if '__relay_services__' not in dir(cls): + return + if '__relay_services__' in cls.__dict__: + for (method, path), (name, kw) in cls.__relay_services__.items(): + if getattr(impl.__class__, name) is getattr(cls, name): + route_method = getattr(openapi_app, 'get' if method == 'list' else method) + route_method(path, **kw)(getattr(impl, name)) + for base in cls.__bases__: + _find_services(base) + + app.update() + _find_services(impl.__class__) + + @classmethod + def build_openapi_app(cls, title: str = 'LazyLLM DocService API', version: str = '1.0.0'): + openapi_app = fastapi.FastAPI( + title=title, + version=version, + description='OpenAPI schema generated from current DocServer routes.', + ) + impl = cls._Impl( + storage_dir=os.path.join(os.getcwd(), '.doc_service_openapi'), + parser_url='http://127.0.0.1:9966', + ) + cls._register_openapi_routes(openapi_app, impl) + return openapi_app + + @classmethod + def build_openapi_schema(cls, title: str = 'LazyLLM DocService API', version: str = '1.0.0'): + return cls.build_openapi_app(title=title, 
version=version).openapi() + + @classmethod + def export_openapi(cls, output_path: str, title: str = 'LazyLLM DocService API', version: str = '1.0.0'): + schema = cls.build_openapi_schema(title=title, version=version) + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as fh: + json.dump(schema, fh, ensure_ascii=False, indent=2, sort_keys=True) + return output_path + def start(self): result = super().start() if self._raw_impl and isinstance(self._impl, ServerModule): diff --git a/lazyllm/tools/sql/sql_manager.py b/lazyllm/tools/sql/sql_manager.py index da4168054..e37e438b9 100644 --- a/lazyllm/tools/sql/sql_manager.py +++ b/lazyllm/tools/sql/sql_manager.py @@ -87,9 +87,13 @@ def _init_tables_by_info(self, tables_info_dict): except pydantic.ValidationError as e: raise ValueError(f'Validate tables_info_dict failed: {str(e)}') - def _sql_type_for(self, py_type: str): + def _sql_type_for(self, py_type: str, *, is_primary_key: bool = False): t = py_type.lower() if self._db_type in ('mysql', 'tidb', 'mysql+pymysql'): + # MySQL/TiDB do not allow TEXT/BLOB columns to be used as primary keys + # without a prefix length. Use VARCHAR for identifier-like key columns. + if is_primary_key and t in ('string', 'text'): + return sqlalchemy.String(255) if t == 'list': return sqlalchemy.JSON if t == 'uuid': @@ -106,8 +110,8 @@ def _create_tables_by_info(self, tables_info: TablesInfo): column_name = column_info.name is_primary = column_info.is_primary_key default_value = column_info.default - # Use text for unsupported column type - real_type = self.PYTYPE_TO_SQL_MAP.get(column_type, sqlalchemy.Text) + # Keep cross-db compatibility while handling MySQL/TiDB PK restrictions. + real_type = self._sql_type_for(column_type, is_primary_key=is_primary) # Handle default value if default_value is not None: attrs[column_name] = sqlalchemy.Column(real_type, nullable=is_nullable, @@ -130,17 +134,58 @@ def _gen_desc_by_info(self, tables_info: TablesInfo) -> dict: desc_dict[table_info.name] = table_comment return desc_dict - def _gen_conn_url(self) -> str: + def _gen_conn_url(self, db_name: str = None) -> str: + db_name = self._db_name if db_name is None else db_name if self._db_type == 'sqlite': - conn_url = f'sqlite:///{self._db_name}{("?" + self._options_str) if self._options_str else ""}' + conn_url = f'sqlite:///{db_name}{("?" + self._options_str) if self._options_str else ""}' else: driver = self.DB_DRIVER_MAP.get(self._db_type if self._db_type != 'tidb' else 'mysql', '') password = quote_plus(self._password) prefix = 'mysql' if self._db_type == 'tidb' else self._db_type + db_path = f'/{db_name}' if db_name else '/' conn_url = (f'{prefix}{("+" + driver) if driver else ""}://{self._user}:{password}@{self._host}' - f':{self._port}/{self._db_name}{("?" + self._options_str) if self._options_str else ""}') + f':{self._port}{db_path}{("?" 
+ self._options_str) if self._options_str else ""}') return conn_url + def _mysql_engine_kwargs(self) -> dict: + kwargs = { + 'pool_size': 10, + 'max_overflow': 20, + 'pool_pre_ping': True, + } + if self._db_type == 'tidb': + kwargs.update({'pool_recycle': 300, 'connect_args': {}, 'echo': False}) + else: + kwargs.update({'pool_recycle': 3600}) + return kwargs + + @staticmethod + def _get_operational_error_code(error: OperationalError): + args = getattr(getattr(error, 'orig', None), 'args', ()) + return args[0] if args else None + + def _ensure_database_exists(self, conn_url: str): + if self._db_type not in ('mysql', 'mysql+pymysql', 'tidb'): + return + probe_engine = sqlalchemy.create_engine(conn_url, **self._mysql_engine_kwargs()) + try: + with probe_engine.connect(): + return + except OperationalError as e: + if self._get_operational_error_code(e) != 1049: + raise + finally: + probe_engine.dispose() + + admin_engine = sqlalchemy.create_engine(self._gen_conn_url(''), **self._mysql_engine_kwargs()) + try: + escaped_db_name = self._db_name.replace('`', '``') + with admin_engine.connect() as conn: + conn.execute(sqlalchemy.text(f'CREATE DATABASE IF NOT EXISTS `{escaped_db_name}`')) + conn.commit() + finally: + admin_engine.dispose() + @property def engine(self): if self._engine is None: @@ -157,16 +202,9 @@ def engine(self): conn.execute(sqlalchemy.text('PRAGMA synchronous=NORMAL')) conn.execute(sqlalchemy.text('PRAGMA busy_timeout=30000')) conn.commit() - elif self._db_type == 'tidb': - self._engine = sqlalchemy.create_engine( - conn_url, - pool_size=10, - max_overflow=20, - pool_pre_ping=True, - pool_recycle=300, - connect_args={}, - echo=False, - ) + elif self._db_type in ('mysql', 'mysql+pymysql', 'tidb'): + self._ensure_database_exists(conn_url) + self._engine = sqlalchemy.create_engine(conn_url, **self._mysql_engine_kwargs()) else: self._engine = sqlalchemy.create_engine( conn_url, From fa3ab7b0328c1dbef7d6100661f4b5aca9072591 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 16 Mar 2026 15:12:04 +0800 Subject: [PATCH 21/46] tmp example --- examples/rag/doc_service_standalone.py | 65 +++++++++++++++++--------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/examples/rag/doc_service_standalone.py b/examples/rag/doc_service_standalone.py index 1c3e58095..64a16ce6e 100644 --- a/examples/rag/doc_service_standalone.py +++ b/examples/rag/doc_service_standalone.py @@ -15,7 +15,6 @@ import argparse import os -import tempfile import threading import time from typing import Any, Dict @@ -27,6 +26,8 @@ from lazyllm.tools.rag.parsing_service import DocumentProcessor REAL_ALGO_ID = 'real-standalone-algo' +FIXED_DB_ROOT = './tmp/db' +DEFAULT_OPENAPI_PATH = os.path.join(FIXED_DB_ROOT, 'doc_service.openapi.json') def _make_db_config(db_name: str) -> Dict[str, Any]: @@ -40,6 +41,20 @@ def _make_db_config(db_name: str) -> Dict[str, Any]: } +def _prepare_runtime_paths() -> Dict[str, str]: + os.makedirs(FIXED_DB_ROOT, exist_ok=True) + paths = { + 'root_dir': FIXED_DB_ROOT, + 'storage_dir': os.path.join(FIXED_DB_ROOT, 'uploads'), + 'store_dir': os.path.join(FIXED_DB_ROOT, 'store'), + 'parser_db': os.path.join(FIXED_DB_ROOT, 'parser.sqlite'), + 'doc_db': os.path.join(FIXED_DB_ROOT, 'doc_service.sqlite'), + } + os.makedirs(paths['storage_dir'], exist_ok=True) + os.makedirs(paths['store_dir'], exist_ok=True) + return paths + + def _wait_until(predicate, timeout: float = 20.0, interval: float = 0.1): deadline = time.time() + timeout last = None @@ -88,22 +103,18 
@@ def _build_store_conf(root_dir: str) -> Dict[str, Any]: def _start_full_stack(args): - tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_standalone_') - storage_dir = os.path.join(tmp_dir, 'uploads') - os.makedirs(storage_dir, exist_ok=True) - parser_db = os.path.join(tmp_dir, 'parser.db') - doc_db = os.path.join(tmp_dir, 'doc_service.db') + paths = _prepare_runtime_paths() parser = DocumentProcessor( port=args.parser_port, - db_config=_make_db_config(parser_db), + db_config=_make_db_config(paths['parser_db']), num_workers=args.num_workers, ) parser.start() parser_base_url = parser._impl._url.rsplit('/', 1)[0] _wait_http_ok(f'{parser_base_url}/health') - store_conf = _build_store_conf(tmp_dir) + store_conf = _build_store_conf(paths['store_dir']) document = Document( dataset_path=None, name=args.algo_id, @@ -131,8 +142,8 @@ def _start_full_stack(args): ) server = DocServer( - storage_dir=storage_dir, - db_config=_make_db_config(doc_db), + storage_dir=paths['storage_dir'], + db_config=_make_db_config(paths['doc_db']), parser_url=parser_base_url, port=args.port, ) @@ -145,10 +156,11 @@ def _start_full_stack(args): print(f'Parser URL: {parser_base_url}', flush=True) print(f'Parser Docs: {parser_base_url}/docs', flush=True) print(f'Algorithm ID: {args.algo_id}', flush=True) - print(f'Storage Dir: {storage_dir}', flush=True) - print(f'Doc DB: {doc_db}', flush=True) - print(f'Parser DB: {parser_db}', flush=True) - print(f'Tmp Dir: {tmp_dir}', flush=True) + print(f'Storage Dir: {paths["storage_dir"]}', flush=True) + print(f'Store Dir: {paths["store_dir"]}', flush=True) + print(f'Doc DB: {paths["doc_db"]}', flush=True) + print(f'Parser DB: {paths["parser_db"]}', flush=True) + print(f'DB Root: {paths["root_dir"]}', flush=True) try: if args.wait: @@ -164,13 +176,10 @@ def _start_full_stack(args): def _start_doc_server_only(args): - tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_standalone_') - storage_dir = os.path.join(tmp_dir, 'uploads') - os.makedirs(storage_dir, exist_ok=True) - doc_db = os.path.join(tmp_dir, 'doc_service.db') + paths = _prepare_runtime_paths() server = DocServer( - storage_dir=storage_dir, - db_config=_make_db_config(doc_db), + storage_dir=paths['storage_dir'], + db_config=_make_db_config(paths['doc_db']), parser_url=args.parser_url, port=args.port, ) @@ -179,8 +188,9 @@ def _start_doc_server_only(args): print(f'DocService URL: {base_url}', flush=True) print(f'DocService Docs: {base_url}/docs', flush=True) print(f'Parser URL: {args.parser_url}', flush=True) - print(f'Storage Dir: {storage_dir}', flush=True) - print(f'Doc DB: {doc_db}', flush=True) + print(f'Storage Dir: {paths["storage_dir"]}', flush=True) + print(f'Doc DB: {paths["doc_db"]}', flush=True) + print(f'DB Root: {paths["root_dir"]}', flush=True) try: if args.wait: @@ -199,8 +209,19 @@ def main(): parser.add_argument('--algo-id', type=str, default=REAL_ALGO_ID, help='Algorithm id to register in full stack mode.') parser.add_argument('--num-workers', type=int, default=1, help='DocumentProcessor worker count.') parser.add_argument('--wait', action='store_true', help='Keep server alive for manual API inspection.') + parser.add_argument( + '--export-openapi', + type=str, + default=None, + help=f'Export current DocService OpenAPI JSON before startup. 
Default path example: {DEFAULT_OPENAPI_PATH}', + ) args = parser.parse_args() + if args.export_openapi: + output_path = DocServer.export_openapi(args.export_openapi) + print(f'OpenAPI exported: {output_path}', flush=True) + return + if args.parser_url: _start_doc_server_only(args) else: From 24b9ffada1bbb281fdad1c2b890e2e4ac7d26fa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 16 Mar 2026 15:31:04 +0800 Subject: [PATCH 22/46] fix docnode copy --- lazyllm/tools/rag/doc_node.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lazyllm/tools/rag/doc_node.py b/lazyllm/tools/rag/doc_node.py index ff6c0f5e2..dd5a73537 100644 --- a/lazyllm/tools/rag/doc_node.py +++ b/lazyllm/tools/rag/doc_node.py @@ -290,11 +290,13 @@ def get_text(self, metadata_mode: MetadataMode = MetadataMode.NONE) -> str: def to_dict(self) -> Dict: return dict(content=self._content, embedding=self.embedding, metadata=self.metadata) - def copy(self, global_metadata: dict = None, metadata: dict = None) -> 'DocNode': + def copy(self, global_metadata: dict = None, metadata: dict = None, + preserve_uid: bool = False) -> 'DocNode': node = copy.copy(self) node._copy_source = {'uid': self.uid, RAG_KB_ID: self.global_metadata.get(RAG_KB_ID), RAG_DOC_ID: self.global_metadata.get(RAG_DOC_ID)} - node._uid = str(uuid.uuid4()) + if not preserve_uid: + node._uid = str(uuid.uuid4()) node._metadata = dict(self._metadata or {}) node._global_metadata = dict(self._global_metadata or {}) if metadata: @@ -304,12 +306,12 @@ def copy(self, global_metadata: dict = None, metadata: dict = None) -> 'DocNode' return node def with_score(self, score): - node = self.copy() + node = self.copy(preserve_uid=True) node.relevance_score = score return node def with_sim_score(self, score): - node = self.copy() + node = self.copy(preserve_uid=True) node.similarity_score = score return node From 3c7f948806398b733906d860c20d838c2cd2907d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 16 Mar 2026 19:47:55 +0800 Subject: [PATCH 23/46] fix sqlmanager pg adaption --- lazyllm/tools/sql/sql_manager.py | 67 +++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 14 deletions(-) diff --git a/lazyllm/tools/sql/sql_manager.py b/lazyllm/tools/sql/sql_manager.py index e37e438b9..685be6732 100644 --- a/lazyllm/tools/sql/sql_manager.py +++ b/lazyllm/tools/sql/sql_manager.py @@ -159,30 +159,72 @@ def _mysql_engine_kwargs(self) -> dict: kwargs.update({'pool_recycle': 3600}) return kwargs + @staticmethod + def _default_engine_kwargs() -> dict: + return { + 'pool_size': 10, + 'max_overflow': 20, + 'pool_pre_ping': True, + 'pool_recycle': 3600, + } + @staticmethod def _get_operational_error_code(error: OperationalError): args = getattr(getattr(error, 'orig', None), 'args', ()) return args[0] if args else None + @staticmethod + def _get_operational_error_pgcode(error: OperationalError): + orig = getattr(error, 'orig', None) + return getattr(orig, 'pgcode', None) or getattr(orig, 'sqlstate', None) + + def _is_database_not_found_error(self, error: OperationalError) -> bool: + if self._db_type in ('mysql', 'mysql+pymysql', 'tidb'): + return self._get_operational_error_code(error) == 1049 + if self._db_type == 'postgresql': + if self._get_operational_error_pgcode(error) == '3D000': + return True + error_msg = str(getattr(error, 'orig', error)).lower() + return 'does not exist' in error_msg and 'database' in error_msg + return False + def _ensure_database_exists(self, conn_url: str): - 
if self._db_type not in ('mysql', 'mysql+pymysql', 'tidb'): + if self._db_type not in ('mysql', 'mysql+pymysql', 'tidb', 'postgresql'): return - probe_engine = sqlalchemy.create_engine(conn_url, **self._mysql_engine_kwargs()) + engine_kwargs = self._mysql_engine_kwargs() if self._db_type in ('mysql', 'mysql+pymysql', 'tidb') \ + else self._default_engine_kwargs() + probe_engine = sqlalchemy.create_engine(conn_url, **engine_kwargs) try: with probe_engine.connect(): return except OperationalError as e: - if self._get_operational_error_code(e) != 1049: + if not self._is_database_not_found_error(e): raise finally: probe_engine.dispose() - admin_engine = sqlalchemy.create_engine(self._gen_conn_url(''), **self._mysql_engine_kwargs()) + if self._db_type == 'postgresql': + admin_engine = sqlalchemy.create_engine( + self._gen_conn_url('postgres'), + isolation_level='AUTOCOMMIT', + **self._default_engine_kwargs() + ) + else: + admin_engine = sqlalchemy.create_engine(self._gen_conn_url(''), **self._mysql_engine_kwargs()) try: - escaped_db_name = self._db_name.replace('`', '``') with admin_engine.connect() as conn: - conn.execute(sqlalchemy.text(f'CREATE DATABASE IF NOT EXISTS `{escaped_db_name}`')) - conn.commit() + if self._db_type == 'postgresql': + exists = conn.execute( + sqlalchemy.text('SELECT 1 FROM pg_database WHERE datname = :db_name'), + {'db_name': self._db_name} + ).scalar() + if not exists: + escaped_db_name = self._db_name.replace('"', '""') + conn.execute(sqlalchemy.text(f'CREATE DATABASE "{escaped_db_name}"')) + else: + escaped_db_name = self._db_name.replace('`', '``') + conn.execute(sqlalchemy.text(f'CREATE DATABASE IF NOT EXISTS `{escaped_db_name}`')) + conn.commit() finally: admin_engine.dispose() @@ -205,14 +247,11 @@ def engine(self): elif self._db_type in ('mysql', 'mysql+pymysql', 'tidb'): self._ensure_database_exists(conn_url) self._engine = sqlalchemy.create_engine(conn_url, **self._mysql_engine_kwargs()) + elif self._db_type == 'postgresql': + self._ensure_database_exists(conn_url) + self._engine = sqlalchemy.create_engine(conn_url, **self._default_engine_kwargs()) else: - self._engine = sqlalchemy.create_engine( - conn_url, - pool_size=10, - max_overflow=20, - pool_pre_ping=True, - pool_recycle=3600 - ) + self._engine = sqlalchemy.create_engine(conn_url, **self._default_engine_kwargs()) return self._engine @property From ca6e49c7f59fe43f5dafafd5b25af8951559e54e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 17 Mar 2026 14:11:49 +0800 Subject: [PATCH 24/46] Merge remote-tracking branch origin/main into cjh/refact-doc-manager --- lazyllm/tools/rag/parsing_service/worker.py | 226 ++++++++------------ 1 file changed, 92 insertions(+), 134 deletions(-) diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index e8281392e..dcdbe6240 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -126,6 +126,8 @@ def _fail_in_progress_task(self, reason: str): return task_id = self._in_progress_task.get('task_id') task_type = self._in_progress_task.get('task_type') + callback_url = self._in_progress_task.get('callback_url') + task_context_json = self._in_progress_task.get('task_context_json') if task_id and task_type: self._enqueue_finished_task( task_id=task_id, @@ -133,6 +135,8 @@ def _fail_in_progress_task(self, reason: str): task_status=TaskStatus.FAILED, error_code='PRESTOP', error_msg=reason, + callback_url=callback_url, + 
task_context_json=task_context_json, ) deleted = self._waiting_task_queue.delete( filter_by={'task_id': task_id, 'worker_id': self._worker_id} @@ -297,6 +301,42 @@ def _exec_update_meta_task(self, processor: _Processor, task_id: str, payload: d LOG.error(f'{self._log_prefix(task_id)} Execute update meta task failed: {e}') raise e + @staticmethod + def _resolve_callback_url(payload: dict): + return payload.get('callback_url') or payload.get('feedback_url') + + @staticmethod + def _build_task_context(task_type: str, payload: dict) -> dict: + items = [] + if task_type in (TaskType.DOC_ADD.value, TaskType.DOC_REPARSE.value, TaskType.DOC_UPDATE_META.value): + file_infos = payload.get('file_infos') or [] + items = [{ + 'doc_id': file_info.get('doc_id'), + 'file_path': file_info.get('file_path'), + 'metadata': file_info.get('metadata'), + 'reparse_group': file_info.get('reparse_group'), + } for file_info in file_infos] + elif task_type == TaskType.DOC_DELETE.value: + items = [{'doc_id': doc_id} for doc_id in (payload.get('doc_ids') or [])] + elif task_type == TaskType.DOC_TRANSFER.value: + file_infos = payload.get('file_infos') or [] + items = [{ + 'doc_id': file_info.get('doc_id'), + 'file_path': file_info.get('file_path'), + 'metadata': file_info.get('metadata'), + 'target_doc_id': (file_info.get('transfer_params') or {}).get('target_doc_id'), + 'target_kb_id': (file_info.get('transfer_params') or {}).get('target_kb_id'), + 'transfer_mode': (file_info.get('transfer_params') or {}).get('mode'), + } for file_info in file_infos] + if not items: + items = [{}] + return { + 'task_type': task_type, + 'kb_id': payload.get('kb_id'), + 'algo_id': payload.get('algo_id'), + 'items': items, + } + def _resolve_task_type(self, request: AddDocRequest) -> str: return _resolve_add_doc_task_type(request) @@ -380,20 +420,27 @@ def _enqueue_task_from_payload(self, task: dict): def _parse_task_payload(self, task: dict): task_type = task.get('task_type') if task_type == TaskType.DOC_DELETE.value: - task_info = DeleteDocRequest(**task) + task_info = DeleteDocRequest.model_validate(task) elif task_type == TaskType.DOC_UPDATE_META.value: - task_info = UpdateMetaRequest(**task) + task_info = UpdateMetaRequest.model_validate(task) else: - task_info = AddDocRequest(**task) + task_info = AddDocRequest.model_validate(task) task_type = task_type or self._resolve_task_type(task_info) task_id = task_info.task_id - payload = task_info.model_dump() + payload = task_info.model_dump(mode='json') self._validate_task_payload(task_type, payload) return task_id, task_type, payload def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: bool): # noqa: C901 + callback_url = self._resolve_callback_url(payload) + task_context_json = json.dumps(self._build_task_context(task_type, payload), ensure_ascii=False) try: - self._in_progress_task = {'task_id': task_id, 'task_type': task_type} + self._in_progress_task = { + 'task_id': task_id, + 'task_type': task_type, + 'callback_url': callback_url, + 'task_context_json': task_context_json, + } if from_queue: self._start_lease_renewal(task_id) algo_id = payload.get('algo_id') @@ -402,6 +449,13 @@ def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: boo LOG.info(f'{self._log_prefix(task_id)} Start processing task: ' f'{self._summarize_task_payload(task_type, payload)}') + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.WORKING, + callback_url=callback_url, + task_context_json=task_context_json, + ) 
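# Editor's descriptive note on the step above (a reading of the patched code, not part of the original diff):
# the WORKING status is pushed to the finished-task queue before the task body is dispatched, presumably so
# the document service can flip the parse snapshot to WORKING via the callback as soon as a worker claims the
# task; the matching SUCCESS/FAILED record is enqueued further below with the same callback_url and
# task_context_json once the handler returns or raises.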
processor = self._get_or_create_processor(algo_id) if task_type == TaskType.DOC_ADD.value: @@ -417,8 +471,15 @@ def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: boo else: raise ValueError(f'{self._log_prefix(task_id)} Unknown task type: {task_type}') - self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FINISHED, - error_code='200', error_msg='success') + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.SUCCESS, + error_code='200', + error_msg='success', + callback_url=callback_url, + task_context_json=task_context_json, + ) if from_queue: deleted = self._waiting_task_queue.delete( filter_by={'task_id': task_id, 'worker_id': self._worker_id} @@ -428,8 +489,15 @@ def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: boo except Exception as e: LOG.error(f'{self._log_prefix(task_id)} Failed to run task: {e}, {traceback.format_exc()}') if task_id and task_type: - self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FAILED, - error_code=type(e).__name__, error_msg=str(e)) + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.FAILED, + error_code=type(e).__name__, + error_msg=str(e), + callback_url=callback_url, + task_context_json=task_context_json, + ) if from_queue: deleted = self._waiting_task_queue.delete( filter_by={'task_id': task_id, 'worker_id': self._worker_id} @@ -470,66 +538,6 @@ def _poll_task(self): exclude_task_types=exclude_types, ) - @staticmethod - def _resolve_callback_url(payload: dict): - return payload.get('callback_url') or payload.get('feedback_url') - - @staticmethod - def _build_task_context(task_type: str, payload: dict) -> dict: - items = [] - if task_type in (TaskType.DOC_ADD.value, TaskType.DOC_REPARSE.value, TaskType.DOC_UPDATE_META.value): - file_infos = payload.get('file_infos') or [] - items = [{ - 'doc_id': file_info.get('doc_id'), - 'file_path': file_info.get('file_path'), - 'metadata': file_info.get('metadata'), - 'reparse_group': file_info.get('reparse_group'), - } for file_info in file_infos] - elif task_type == TaskType.DOC_DELETE.value: - items = [{'doc_id': doc_id} for doc_id in (payload.get('doc_ids') or [])] - if not items: - items = [{}] - return { - 'task_type': task_type, - 'kb_id': payload.get('kb_id'), - 'algo_id': payload.get('algo_id'), - 'items': items, - } - - @staticmethod - def _infer_task_type(task_type: str, payload: dict) -> str: - if task_type == TaskType.DOC_DELETE.value: - return task_type - if task_type == TaskType.DOC_UPDATE_META.value: - return task_type - - file_infos = payload.get('file_infos') or [] - has_reparse = any(file_info.get('reparse_group') is not None for file_info in file_infos) - if has_reparse: - return TaskType.DOC_REPARSE.value - return task_type - - def _parse_task_payload(self, task_data: dict): - task_id = task_data.get('task_id') - task_type = task_data.get('task_type') - if not task_id: - raise ValueError('task_id is required') - if not task_type: - raise ValueError('task_type is required') - - if task_type == TaskType.DOC_DELETE.value: - payload = DeleteDocRequest.model_validate(task_data).model_dump(mode='json') - elif task_type == TaskType.DOC_UPDATE_META.value: - payload = UpdateMetaRequest.model_validate(task_data).model_dump(mode='json') - else: - payload = AddDocRequest.model_validate(task_data).model_dump(mode='json') - - task_type = self._infer_task_type(task_type, payload) - return 
task_id, task_type, payload - - def _summarize_task_payload(self, task_type: str, payload: dict) -> str: - return json.dumps(self._build_task_context(task_type, payload), ensure_ascii=False, sort_keys=True) - def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: TaskStatus, error_code: str = None, error_msg: str = None, callback_url: str = None, task_context_json: str = None): @@ -548,7 +556,7 @@ def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: Task if task_status == TaskStatus.WORKING: LOG.info(f'{self._log_prefix(task_id)} Task started') elif task_status == TaskStatus.SUCCESS: - LOG.info(f'{self._log_prefix(task_id)} Task {task_id} completed successfully') + LOG.info(f'{self._log_prefix(task_id)} Task completed successfully') else: LOG.error(f'{self._log_prefix(task_id)} Task completed with status {task_status}: {error_msg}') except Exception as e: @@ -557,76 +565,24 @@ def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: Task def _worker_impl(self): # noqa: C901 while not self._shutdown: try: - task_data = self._waiting_task_queue.dequeue() - if not task_data: - time.sleep(0.1) - continue - - raw_payload = json.loads(task_data.get('message')) - task_id, task_type, payload = self._parse_task_payload({ - 'task_id': task_data.get('task_id'), - 'task_type': task_data.get('task_type'), - **raw_payload, - }) - algo_id = payload.get('algo_id') - if not algo_id: - raise ValueError(f'[DocumentProcessorWorker._Impl] task_id {task_id} is missing algo_id in ' - f'payload: {payload}') - callback_url = self._resolve_callback_url(payload) - task_context_json = json.dumps(self._build_task_context(task_type, payload), ensure_ascii=False) - task_summary = self._summarize_task_payload(task_type, payload) - - LOG.info(f'[DocumentProcessorWorker._Impl] Start processing task {task_id}, type: {task_type},' - f' algo_id: {algo_id}, payload={task_summary}') - self._enqueue_finished_task( - task_id=task_id, - task_type=task_type, - task_status=TaskStatus.WORKING, - callback_url=callback_url, - task_context_json=task_context_json, - ) - - processor = self._get_or_create_processor(algo_id) - if task_type == TaskType.DOC_ADD.value: - self._exec_add_task(processor, task_id, payload) - elif task_type == TaskType.DOC_REPARSE.value: - self._exec_reparse_task(processor, task_id, payload) - elif task_type == TaskType.DOC_DELETE.value: - self._exec_delete_task(processor, task_id, payload) - elif task_type == TaskType.DOC_UPDATE_META.value: - self._exec_update_meta_task(processor, task_id, payload) - else: - raise ValueError(f'[DocumentProcessorWorker._Impl] Unknown task type: {task_type}') - - self._enqueue_finished_task( - task_id=task_id, - task_type=task_type, - task_status=TaskStatus.SUCCESS, - error_code='200', - error_msg='success', - callback_url=callback_url, - task_context_json=task_context_json, - ) + task_data = self._poll_task() except Exception as e: - LOG.error(f'[DocumentProcessorWorker._Impl] Failed to run task {task_id}: {e},' - f' {traceback.format_exc()}') - if task_id and task_type: - callback_url = locals().get('callback_url') - task_context_json = locals().get('task_context_json') - self._enqueue_finished_task( - task_id=task_id, - task_type=task_type, - task_status=TaskStatus.FAILED, - error_code=type(e).__name__, - error_msg=str(e), - callback_url=callback_url, - task_context_json=task_context_json, - ) + LOG.error(f'{self._log_prefix()} [Worker] poll_task failed: {e}, {traceback.format_exc()}') 
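# Descriptive note (editorial, not from the original patch): a failed poll is treated as transient here --
# log it, back off for WORKER_ERROR_RETRY_INTERVAL seconds (a constant assumed to be defined elsewhere in
# this module), and continue the dequeue loop rather than letting the worker thread die.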
time.sleep(WORKER_ERROR_RETRY_INTERVAL) continue if task_data: + payload = None + callback_url = None + task_context_json = None try: payload = json.loads(task_data.get('message')) + callback_url = self._resolve_callback_url(payload) if isinstance(payload, dict) else None + if isinstance(payload, dict): + task_context_json = json.dumps( + self._build_task_context(task_data['task_type'], payload), + ensure_ascii=False, + ) + self._validate_task_payload(task_data['task_type'], payload) except Exception as e: task_id = task_data.get('task_id') task_type = task_data.get('task_type') @@ -640,6 +596,8 @@ def _worker_impl(self): # noqa: C901 task_status=TaskStatus.FAILED, error_code=type(e).__name__, error_msg=str(e), + callback_url=callback_url, + task_context_json=task_context_json, ) deleted = self._waiting_task_queue.delete( filter_by={'task_id': task_id, 'worker_id': self._worker_id} From 77cc90d274699bcaabcf16b2847e3a338be0af4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 17 Mar 2026 14:29:26 +0800 Subject: [PATCH 25/46] fix reparse whole file --- lazyllm/tools/rag/parsing_service/impl.py | 26 +++++++++++++++-------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index cd520663d..e662364c9 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -90,6 +90,21 @@ def store(self) -> _DocumentStore: def reader(self) -> DirectoryReader: return self._reader + @staticmethod + def _prepare_doc_inputs(input_files: List[str], ids: Optional[List[str]] = None, + metadatas: Optional[List[Dict[str, Any]]] = None, + kb_id: Optional[str] = None) -> tuple[List[str], List[Dict[str, Any]], str]: + ids = ids or [gen_docid(path) for path in input_files] + normalized_metadatas = metadatas or [{} for _ in input_files] + for i, (doc_id, path) in enumerate(zip(ids, input_files)): + metadata = normalized_metadatas[i] or {} + metadata.setdefault(RAG_DOC_ID, doc_id) + metadata.setdefault(RAG_DOC_PATH, path) + metadata.setdefault(RAG_KB_ID, kb_id or DEFAULT_KB_ID) + normalized_metadatas[i] = metadata + resolved_kb_id = normalized_metadatas[0].get(RAG_KB_ID, DEFAULT_KB_ID) if kb_id is None else kb_id + return ids, normalized_metadatas, resolved_kb_id + def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # noqa: C901 metadatas: Optional[List[Dict[str, Any]]] = None, kb_id: Optional[str] = None, transfer_mode: Optional[str] = None, target_kb_id: Optional[str] = None, @@ -98,14 +113,7 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no try: if not input_files: return add_start = time.time() - if not ids: ids = [gen_docid(path) for path in input_files] - if metadatas is None: - metadatas = [{} for _ in input_files] - for metadata, doc_id, path in zip(metadatas, ids, input_files): - metadata.setdefault(RAG_DOC_ID, doc_id) - metadata.setdefault(RAG_DOC_PATH, path) - metadata.setdefault(RAG_KB_ID, kb_id or DEFAULT_KB_ID) - kb_id = metadatas[0].get(RAG_KB_ID, DEFAULT_KB_ID) if kb_id is None else kb_id + ids, metadatas, kb_id = self._prepare_doc_inputs(input_files, ids, metadatas, kb_id) load_start = time.time() if transfer_mode is None: @@ -277,7 +285,7 @@ def _reparse_docs(self, group_name: str, doc_ids: List[str], doc_paths: List[str kb_id: str = None, **kwargs): if not metadatas: raise ValueError('metadatas is required for reparse') - kb_id = metadatas[0].get(RAG_KB_ID, None) if kb_id is 
None else kb_id + doc_ids, metadatas, kb_id = self._prepare_doc_inputs(doc_paths, doc_ids, metadatas, kb_id) if group_name == 'all': preloaded_root_nodes = self._reader.load_data(doc_paths, metadatas, split_nodes_by_type=True) self._store.remove_nodes(doc_ids=doc_ids, kb_id=kb_id) From c73fababe338bb2bd0efadd35cb0d5f473951992 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 18 Mar 2026 14:14:50 +0800 Subject: [PATCH 26/46] add chunk list api --- lazyllm/tools/rag/doc_service/doc_manager.py | 425 ++++++++++++++----- lazyllm/tools/rag/doc_service/doc_server.py | 29 +- lazyllm/tools/rag/parsing_service/base.py | 2 +- lazyllm/tools/rag/parsing_service/server.py | 86 +++- lazyllm/tools/rag/parsing_service/worker.py | 48 ++- lazyllm/tools/rag/store/document_store.py | 15 +- 6 files changed, 481 insertions(+), 124 deletions(-) diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index 96fb654ff..07cef5ad4 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -33,7 +33,6 @@ ReparseRequest, SourceType, TaskCallbackRequest, - TaskCreateRequest, TaskType, TransferRequest, UploadRequest, @@ -45,7 +44,7 @@ CancelTaskRequest as ParsingCancelTaskRequest, DeleteDocRequest as ParsingDeleteDocRequest, FileInfo as ParsingFileInfo, - TaskStatus, + TransferParams as ParsingTransferParams, UpdateMetaRequest as ParsingUpdateMetaRequest, ) @@ -128,7 +127,7 @@ def _get_with_fallback(self, paths: List[str], params: Optional[Dict[str, Any]] def add_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, file_path: str, metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, - callback_url: Optional[str] = None): + callback_url: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None): req = ParsingAddDocRequest( task_id=task_id, algo_id=algo_id, @@ -140,6 +139,10 @@ def add_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, file_path doc_id=doc_id, metadata=metadata or {}, reparse_group=reparse_group, + transfer_params=( + ParsingTransferParams.model_validate(transfer_params) + if transfer_params is not None else None + ), )], ) data = self._post('/doc/add', req.model_dump(mode='json')) @@ -193,6 +196,17 @@ def get_algorithm_groups(self, algo_id: str): return BaseResponse(code=404, msg='algo not found', data=None) raise + def list_doc_chunks(self, algo_id: str, kb_id: str, doc_id: str, group: str, offset: int, page_size: int): + data = self._get('/doc/chunks', params={ + 'algo_id': algo_id, + 'kb_id': kb_id, + 'doc_id': doc_id, + 'group': group, + 'offset': offset, + 'page_size': page_size, + }) + return BaseResponse.model_validate(data) + class DocManager: def __init__( @@ -487,6 +501,14 @@ def _record_callback(self, callback_id: str, task_id: str): session.rollback() return False + def _forget_callback_record(self, callback_id: str, task_id: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(CALLBACK_RECORDS_TABLE_INFO['name']) + session.query(Record).filter( + Record.callback_id == callback_id, + Record.task_id == task_id, + ).delete() + def _create_task_record(self, task_id: str, task_type: TaskType, doc_id: str, kb_id: str, algo_id: str, status: DocStatus, message: Optional[Dict[str, Any]] = None): now = now_ts() @@ -695,7 +717,10 @@ def _prepare_kb_delete_items(self, kb_id: str) -> Dict[str, Any]: default_algo_id = binding['algo_id'] if binding is not None else 
'__default__' items = [] for doc_id in self._list_kb_doc_ids(kb_id): - snapshot = self._get_parse_snapshot(doc_id, kb_id, default_algo_id) or self._get_latest_parse_snapshot(doc_id, kb_id) + snapshot = ( + self._get_parse_snapshot(doc_id, kb_id, default_algo_id) + or self._get_latest_parse_snapshot(doc_id, kb_id) + ) if snapshot is None or snapshot.get('status') == DocStatus.DELETED.value: items.append({'action': 'purge_local', 'doc_id': doc_id}) continue @@ -850,13 +875,15 @@ def _call_parser_client(self, method, *args, **kwargs): def _create_parser_task(self, task_id: str, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, - reparse_group: Optional[str] = None): + reparse_group: Optional[str] = None, parser_kb_id: Optional[str] = None, + transfer_params: Optional[Dict[str, Any]] = None): if task_type in (TaskType.DOC_ADD, TaskType.DOC_TRANSFER): if not file_path: raise RuntimeError(f'file_path is required for task_type {task_type.value}') task_resp = self._call_parser_client( self._parser_client.add_doc, - task_id, algo_id, kb_id, doc_id, file_path, metadata, callback_url=self._callback_url, + task_id, algo_id, parser_kb_id or kb_id, doc_id, file_path, metadata, + callback_url=self._callback_url, transfer_params=transfer_params, ) elif task_type == TaskType.DOC_REPARSE: if not file_path: @@ -887,6 +914,8 @@ def _enqueue_task( idempotency_key: Optional[str] = None, priority: int = 0, file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, cleanup_policy: Optional[str] = None, + parser_kb_id: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None, + extra_message: Optional[Dict[str, Any]] = None, ): task_id = str(uuid4()) task_message = { @@ -897,8 +926,12 @@ def _enqueue_task( 'metadata': metadata, 'reparse_group': reparse_group, } + if extra_message: + task_message.update(extra_message) if cleanup_policy: task_message['cleanup_policy'] = cleanup_policy + if transfer_params: + task_message['transfer_params'] = transfer_params task_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING self._create_task_record(task_id, task_type, doc_id, kb_id, algo_id, task_status, message=task_message) parse_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING @@ -922,6 +955,7 @@ def _enqueue_task( self._create_parser_task( task_id, doc_id, kb_id, algo_id, task_type, file_path=file_path, metadata=metadata, reparse_group=reparse_group, + parser_kb_id=parser_kb_id, transfer_params=transfer_params, ) except Exception as exc: finished_at = now_ts() @@ -1008,9 +1042,15 @@ def _prepare_delete_items(self, request: DeleteRequest) -> List[Dict[str, Any]]: prepared_items = [] for doc_id in request.doc_ids: doc = self._get_doc(doc_id) - snapshot = self._get_parse_snapshot(doc_id, request.kb_id, request.algo_id) or self._get_latest_parse_snapshot(doc_id, request.kb_id) + snapshot = ( + self._get_parse_snapshot(doc_id, request.kb_id, request.algo_id) + or self._get_latest_parse_snapshot(doc_id, request.kb_id) + ) if doc is None or not self._has_kb_document(request.kb_id, doc_id): - if snapshot is not None and snapshot.get('status') in (DocStatus.DELETING.value, DocStatus.DELETED.value): + if snapshot is not None and snapshot.get('status') in ( + DocStatus.DELETING.value, + DocStatus.DELETED.value, + ): prepared_items.append({ 'doc_id': doc_id, 'action': 'noop', @@ -1043,6 +1083,87 @@ def 
_prepare_metadata_patch_items(self, request: MetadataPatchRequest) -> List[D prepared_items.append({'doc_id': item.doc_id, 'metadata': merged, 'file_path': doc.get('path')}) return prepared_items + def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, Any]]: + prepared_items = [] + seen_pairs = set() + seen_targets = set() + for item in request.items: + if item.mode not in ('move', 'copy'): + raise DocServiceError( + 'E_INVALID_PARAM', f'invalid transfer mode: {item.mode}', {'mode': item.mode} + ) + item_key = (item.doc_id, item.source_kb_id, item.target_kb_id) + if item_key in seen_pairs: + raise DocServiceError( + 'E_INVALID_PARAM', + 'duplicate transfer item detected', + {'doc_id': item.doc_id, 'source_kb_id': item.source_kb_id, 'target_kb_id': item.target_kb_id}, + ) + seen_pairs.add(item_key) + target_key = (item.doc_id, item.target_kb_id, item.target_algo_id) + if target_key in seen_targets: + raise DocServiceError( + 'E_INVALID_PARAM', + 'duplicate transfer target detected', + { + 'doc_id': item.doc_id, + 'target_kb_id': item.target_kb_id, + 'target_algo_id': item.target_algo_id, + }, + ) + seen_targets.add(target_key) + if item.source_algo_id != item.target_algo_id: + raise DocServiceError( + 'E_INVALID_PARAM', + 'transfer across different algorithms is not supported', + { + 'doc_id': item.doc_id, + 'source_algo_id': item.source_algo_id, + 'target_algo_id': item.target_algo_id, + }, + ) + doc = self._get_doc(item.doc_id) + if doc is None or not self._has_kb_document(item.source_kb_id, item.doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') + self._validate_kb_algorithm(item.source_kb_id, item.source_algo_id) + self._validate_kb_algorithm(item.target_kb_id, item.target_algo_id) + self._assert_action_allowed(item.doc_id, item.source_kb_id, item.source_algo_id, 'transfer') + if self._has_kb_document(item.target_kb_id, item.doc_id): + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc already exists in target kb: {item.doc_id}', + {'doc_id': item.doc_id, 'target_kb_id': item.target_kb_id}, + ) + source_snapshot = self._get_parse_snapshot(item.doc_id, item.source_kb_id, item.source_algo_id) + if source_snapshot is None or source_snapshot.get('status') != DocStatus.SUCCESS.value: + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc transfer requires source parse status SUCCESS: {item.doc_id}', + { + 'doc_id': item.doc_id, + 'source_kb_id': item.source_kb_id, + 'source_algo_id': item.source_algo_id, + 'status': source_snapshot.get('status') if source_snapshot else None, + }, + ) + prepared_items.append({ + 'doc_id': item.doc_id, + 'source_kb_id': item.source_kb_id, + 'source_algo_id': item.source_algo_id, + 'target_kb_id': item.target_kb_id, + 'target_algo_id': item.target_algo_id, + 'mode': item.mode, + 'file_path': doc.get('path'), + 'metadata': _from_json(doc.get('meta')), + 'transfer_params': { + 'mode': 'mv' if item.mode == 'move' else 'cp', + 'target_algo_id': item.target_algo_id, + 'target_doc_id': item.doc_id, + 'target_kb_id': item.target_kb_id, + }, + }) + return prepared_items + def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: self._validate_kb_algorithm(request.kb_id, request.algo_id) prepared_items = self._prepare_upload_items(request) @@ -1143,7 +1264,11 @@ def delete(self, request: DeleteRequest) -> List[Dict[str, Any]]: raise DocServiceError( 'E_STATE_CONFLICT', cancel_resp.msg, - cancel_resp.data if isinstance(cancel_resp.data, dict) else {'task_id': snapshot['current_task_id']}, + ( + 
cancel_resp.data + if isinstance(cancel_resp.data, dict) + else {'task_id': snapshot['current_task_id']} + ), ) items.append({ 'doc_id': doc_id, @@ -1175,36 +1300,50 @@ def delete(self, request: DeleteRequest) -> List[Dict[str, Any]]: return items def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: + prepared_items = self._prepare_transfer_items(request) items: List[Dict[str, Any]] = [] - for item in request.items: - if item.mode not in ('move', 'copy'): - raise DocServiceError( - 'E_INVALID_PARAM', f'invalid transfer mode: {item.mode}', {'mode': item.mode} + for item in prepared_items: + task_id = None + try: + self._ensure_kb_document(item['target_kb_id'], item['doc_id']) + task_id, snapshot = self._enqueue_task( + item['doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, + idempotency_key=request.idempotency_key, + file_path=item['file_path'], + metadata=item['metadata'], + parser_kb_id=item['source_kb_id'], + transfer_params=item['transfer_params'], + extra_message={ + 'source_kb_id': item['source_kb_id'], + 'source_algo_id': item['source_algo_id'], + 'target_kb_id': item['target_kb_id'], + 'target_algo_id': item['target_algo_id'], + 'mode': item['mode'], + }, ) - doc = self._get_doc(item.doc_id) - if doc is None or not self._has_kb_document(item.source_kb_id, item.doc_id): - raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') - self._validate_kb_algorithm(item.source_kb_id, item.source_algo_id) - self._validate_kb_algorithm(item.target_kb_id, item.target_algo_id) - self._assert_action_allowed(item.doc_id, item.source_kb_id, item.source_algo_id, 'transfer') - self._ensure_kb_document(item.target_kb_id, item.doc_id) - if item.mode == 'move': - self._remove_kb_document(item.source_kb_id, item.doc_id) - task_id, snapshot = self._enqueue_task( - item.doc_id, item.target_kb_id, item.target_algo_id, TaskType.DOC_TRANSFER, - idempotency_key=request.idempotency_key, - file_path=doc.get('path'), - metadata=_from_json(doc.get('meta')), - ) + error_code = None + error_msg = None + accepted = True + except Exception as exc: + snapshot = self._get_parse_snapshot(item['doc_id'], item['target_kb_id'], item['target_algo_id']) or {} + task_id = task_id or snapshot.get('current_task_id') + error_code = snapshot.get('last_error_code') + if not error_code: + error_code = exc.biz_code if isinstance(exc, DocServiceError) else type(exc).__name__ + error_msg = snapshot.get('last_error_msg') or (exc.msg if isinstance(exc, DocServiceError) else str(exc)) + accepted = False items.append({ - 'doc_id': item.doc_id, + 'doc_id': item['doc_id'], 'task_id': task_id, - 'source_kb_id': item.source_kb_id, - 'target_kb_id': item.target_kb_id, - 'source_algo_id': item.source_algo_id, - 'target_algo_id': item.target_algo_id, - 'mode': item.mode, - 'status': snapshot['status'], + 'source_kb_id': item['source_kb_id'], + 'target_kb_id': item['target_kb_id'], + 'source_algo_id': item['source_algo_id'], + 'target_algo_id': item['target_algo_id'], + 'mode': item['mode'], + 'status': snapshot.get('status', DocStatus.FAILED.value), + 'accepted': accepted, + 'error_code': error_code, + 'error_msg': error_msg, }) return items @@ -1279,95 +1418,110 @@ def _resolve_callback_task(self, callback: TaskCallbackRequest): } return None - def on_task_callback(self, callback: TaskCallbackRequest): + def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 if not self._record_callback(callback.callback_id, callback.task_id): return {'ack': True, 'deduped': True, 
'ignored_reason': None} - task_data = self._resolve_callback_task(callback) - if task_data is None: - return {'ack': True, 'ignored_reason': 'task_not_found'} - doc_id = task_data['doc_id'] - kb_id = task_data['kb_id'] - algo_id = task_data['algo_id'] - task_type = TaskType(task_data['task_type']) - snapshot = self._get_parse_snapshot(doc_id, kb_id, algo_id) - if snapshot and snapshot.get('current_task_id') and snapshot['current_task_id'] != callback.task_id: - return {'ack': True, 'deduped': False, 'ignored_reason': 'stale_task_callback'} - if task_data.get('status') == DocStatus.CANCELED.value and callback.status != DocStatus.CANCELED: - return {'ack': True, 'deduped': False, 'ignored_reason': 'canceled_task_callback'} + try: + task_data = self._resolve_callback_task(callback) + if task_data is None: + return {'ack': True, 'ignored_reason': 'task_not_found'} + doc_id = task_data['doc_id'] + kb_id = task_data['kb_id'] + algo_id = task_data['algo_id'] + task_type = TaskType(task_data['task_type']) + snapshot = self._get_parse_snapshot(doc_id, kb_id, algo_id) + if snapshot and snapshot.get('current_task_id') and snapshot['current_task_id'] != callback.task_id: + return {'ack': True, 'deduped': False, 'ignored_reason': 'stale_task_callback'} + if task_data.get('status') == DocStatus.CANCELED.value and callback.status != DocStatus.CANCELED: + return {'ack': True, 'deduped': False, 'ignored_reason': 'canceled_task_callback'} + + if callback.event_type == CallbackEventType.START: + self._update_task_record( + callback.task_id, + status=DocStatus.WORKING.value, + started_at=now_ts(), + finished_at=None, + error_code=None, + error_msg=None, + ) + start_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WORKING + self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=start_status, + **self._build_snapshot_update( + snapshot, + task_type=task_type, + current_task_id=callback.task_id, + started_at=now_ts(), + finished_at=None, + error_code=None, + error_msg=None, + failed_stage=None, + ), + ) + if task_type == TaskType.DOC_ADD: + self._apply_doc_upload_status(doc_id, task_type, DocStatus.WORKING) + elif task_type == TaskType.DOC_DELETE: + self._apply_doc_upload_status(doc_id, task_type, DocStatus.DELETING) + return {'ack': True, 'deduped': False, 'ignored_reason': None} + + final_status = callback.status + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.SUCCESS: + final_status = DocStatus.DELETED + failed_stage = None + if final_status == DocStatus.FAILED: + failed_stage = 'DELETE' if task_type == TaskType.DOC_DELETE else 'PARSE' + task_message = task_data.get('message') or {} + cleanup_policy = task_message.get('cleanup_policy') + + if ( + task_type == TaskType.DOC_TRANSFER + and final_status == DocStatus.SUCCESS + and task_message.get('mode') == 'move' + ): + source_kb_id = task_message.get('source_kb_id') + if source_kb_id and source_kb_id != kb_id: + self._remove_kb_document(source_kb_id, doc_id) + self._delete_parse_snapshots(doc_id, source_kb_id) + self._sync_doc_upload_status(doc_id) - if callback.event_type == CallbackEventType.START: self._update_task_record( callback.task_id, - status=DocStatus.WORKING.value, - started_at=now_ts(), - finished_at=None, - error_code=None, - error_msg=None, + status=final_status.value, + error_code=callback.error_code, + error_msg=callback.error_msg, + finished_at=now_ts(), ) - start_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WORKING + 
self._upsert_parse_snapshot( doc_id=doc_id, kb_id=kb_id, algo_id=algo_id, - status=start_status, + status=final_status, **self._build_snapshot_update( snapshot, task_type=task_type, current_task_id=callback.task_id, - started_at=now_ts(), - finished_at=None, - error_code=None, - error_msg=None, - failed_stage=None, + error_code=callback.error_code, + error_msg=callback.error_msg, + failed_stage=failed_stage, + finished_at=now_ts(), ), ) - if task_type == TaskType.DOC_ADD: - self._apply_doc_upload_status(doc_id, task_type, DocStatus.WORKING) - elif task_type == TaskType.DOC_DELETE: - self._apply_doc_upload_status(doc_id, task_type, DocStatus.DELETING) - return {'ack': True, 'deduped': False, 'ignored_reason': None} - final_status = callback.status - if task_type == TaskType.DOC_DELETE and final_status == DocStatus.SUCCESS: - final_status = DocStatus.DELETED - failed_stage = None - if final_status == DocStatus.FAILED: - failed_stage = 'DELETE' if task_type == TaskType.DOC_DELETE else 'PARSE' - task_message = task_data.get('message') or {} - cleanup_policy = task_message.get('cleanup_policy') - - self._update_task_record( - callback.task_id, - status=final_status.value, - error_code=callback.error_code, - error_msg=callback.error_msg, - finished_at=now_ts(), - ) + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED: + self._remove_kb_document(kb_id, doc_id) + self._apply_doc_upload_status(doc_id, task_type, final_status) + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED and cleanup_policy == 'purge': + self._purge_deleted_kb_doc_data(kb_id, doc_id) + self._finalize_kb_deletion_if_empty(kb_id) - self._upsert_parse_snapshot( - doc_id=doc_id, - kb_id=kb_id, - algo_id=algo_id, - status=final_status, - **self._build_snapshot_update( - snapshot, - task_type=task_type, - current_task_id=callback.task_id, - error_code=callback.error_code, - error_msg=callback.error_msg, - failed_stage=failed_stage, - finished_at=now_ts(), - ), - ) - - if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED: - self._remove_kb_document(kb_id, doc_id) - self._apply_doc_upload_status(doc_id, task_type, final_status) - if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED and cleanup_policy == 'purge': - self._purge_deleted_kb_doc_data(kb_id, doc_id) - self._finalize_kb_deletion_if_empty(kb_id) - - return {'ack': True, 'deduped': False, 'ignored_reason': None} + return {'ack': True, 'deduped': False, 'ignored_reason': None} + except Exception: + self._forget_callback_record(callback.callback_id, callback.task_id) + raise def list_docs( self, @@ -1548,8 +1702,57 @@ def get_algorithm_info(self, algo_id: str): return data raise DocServiceError('E_NOT_FOUND', f'algo not found: {algo_id}') - def list_chunks(self, page: int = 1, page_size: int = 20): - return {'items': [], 'total': 0, 'page': page, 'page_size': page_size} + def list_chunks( + self, + kb_id: str, + doc_id: str, + group: str, + algo_id: str = '__default__', + page: int = 1, + page_size: int = 20, + offset: Optional[int] = None, + ): + if not kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required', {'kb_id': kb_id}) + if not doc_id: + raise DocServiceError('E_INVALID_PARAM', 'doc_id is required', {'doc_id': doc_id}) + if not group: + raise DocServiceError('E_INVALID_PARAM', 'group is required', {'group': group}) + self._validate_kb_algorithm(kb_id, algo_id) + doc = self._get_doc(doc_id) + if doc is None or not self._has_kb_document(kb_id, doc_id): + raise 
DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}', {'kb_id': kb_id, 'doc_id': doc_id}) + groups = self.get_algo_groups(algo_id) + if not any(item.get('name') == group for item in groups): + raise DocServiceError( + 'E_INVALID_PARAM', + f'invalid group: {group}', + {'algo_id': algo_id, 'group': group}, + ) + page = max(page, 1) + page_size = max(page_size, 1) + offset = (page - 1) * page_size if offset is None else max(offset, 0) + resp = self._parser_client.list_doc_chunks( + algo_id=algo_id, + kb_id=kb_id, + doc_id=doc_id, + group=group, + offset=offset, + page_size=page_size, + ) + if resp.code == 404: + raise DocServiceError('E_NOT_FOUND', resp.msg, {'kb_id': kb_id, 'doc_id': doc_id, 'group': group}) + if resp.code == 400: + raise DocServiceError('E_INVALID_PARAM', resp.msg, {'kb_id': kb_id, 'doc_id': doc_id, 'group': group}) + if resp.code != 200: + raise fastapi.HTTPException(status_code=502, detail=resp.msg) + data = dict(resp.data or {}) + data['page'] = page + data['page_size'] = page_size + data['offset'] = offset + data.setdefault('items', []) + data.setdefault('total', 0) + return data def health(self): return { @@ -1579,7 +1782,11 @@ def list_kbs( if keyword: like_expr = f'%{keyword}%' query = query.filter( - sqlalchemy.or_(Kb.kb_id.like(like_expr), Kb.display_name.like(like_expr), Kb.description.like(like_expr)) + sqlalchemy.or_( + Kb.kb_id.like(like_expr), + Kb.display_name.like(like_expr), + Kb.description.like(like_expr), + ) ) if status: query = query.filter(Kb.status.in_(status)) diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py index 056e34f06..1308a8f60 100644 --- a/lazyllm/tools/rag/doc_service/doc_server.py +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -18,7 +18,6 @@ from .base import TransferRequest from .base import UploadRequest, AddFileItem from .doc_manager import DocManager -from ..parsing_service.base import TaskStatus, TaskType class DocServer(ModuleBase): @@ -128,7 +127,8 @@ def _gen_unique_upload_path(self, filename: str, reserved_paths: Optional[set] = candidate = os.path.join(self._storage_dir, f'{prefix}-{idx}{suffix}') if candidate not in reserved_paths and not os.path.exists(candidate): return candidate - return os.path.join(self._storage_dir, f'{prefix}-{hashlib.sha256(safe_name.encode()).hexdigest()[:8]}{suffix}') + digest = hashlib.sha256(safe_name.encode()).hexdigest()[:8] + return os.path.join(self._storage_dir, f'{prefix}-{digest}{suffix}') def _run_upload(self, request: UploadRequest, payload: Optional[Dict[str, Any]] = None): idem_payload = payload or self._build_upload_payload(request) @@ -421,9 +421,26 @@ def get_algorithm_info_impl(self, algo_id: str): return self._run(lambda: self._manager.get_algorithm_info(algo_id)) @app.get('/v1/chunks') - def list_chunks(self, page: int = 1, page_size: int = 20): + def list_chunks( + self, + kb_id: str, + doc_id: str, + group: str, + algo_id: str = '__default__', + page: int = 1, + page_size: int = 20, + offset: Optional[int] = None, + ): self._lazy_init() - return self._run(lambda: self._manager.list_chunks(page=page, page_size=page_size)) + return self._run(lambda: self._manager.list_chunks( + kb_id=kb_id, + doc_id=doc_id, + group=group, + algo_id=algo_id, + page=page, + page_size=page_size, + offset=offset, + )) @app.post('/v1/tasks/batch') async def get_tasks_batch(self, request: 'fastapi.Request'): @@ -730,8 +747,8 @@ def list_kbs(self, **kwargs): def get_kb(self, kb_id: str): return self._dispatch('get_kb', kb_id) - def 
list_chunks(self, page: int = 1, page_size: int = 20): - return self._dispatch('list_chunks', page, page_size) + def list_chunks(self, **kwargs): + return self._dispatch('list_chunks', **kwargs) def list_algorithms(self): return self._dispatch('list_algorithms_impl') diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index 467d8e935..690786098 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, BeforeValidator, model_validator, model_validator +from pydantic import BaseModel, Field, BeforeValidator, model_validator from typing import Dict, List, Optional, Any, Annotated from enum import Enum from uuid import uuid4 diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index c1ae313c1..674aefb3b 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -18,9 +18,8 @@ from .base import ( ALGORITHM_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, FINISHED_TASK_QUEUE_TABLE_INFO, - TaskStatus, TaskStatus, TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, - - _calculate_task_score, _resolve_add_doc_task_type + TaskStatus, TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, + _calculate_task_score ) from .worker import DocumentProcessorWorker as Worker from .queue import _SQLBasedQueue as Queue @@ -479,6 +478,87 @@ def _get_algo_group_info_data(self, algo_id: str): data.append(group_info) return data + @staticmethod + def _format_chunk_item(segment: Dict[str, Any]) -> Dict[str, Any]: + return { + 'uid': segment.get('uid'), + 'doc_id': segment.get('doc_id'), + 'kb_id': segment.get('kb_id'), + 'group': segment.get('group'), + 'number': segment.get('number', 0), + 'content': segment.get('content'), + 'type': segment.get('type'), + 'parent': segment.get('parent'), + 'metadata': segment.get('meta', {}), + 'global_metadata': segment.get('global_meta', {}), + 'answer': segment.get('answer', ''), + 'image_keys': segment.get('image_keys', []), + } + + def _list_doc_chunks_data( + self, + algo_id: str, + kb_id: str, + doc_id: str, + group: str, + offset: int = 0, + limit: int = 20, + ) -> Dict[str, Any]: + algorithm = self._get_algo(algo_id) + if algorithm is None: + raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') + info = load_obj(algorithm.get('info_pickle')) + store: _DocumentStore = info['store'] # type: ignore + node_groups = info.get('node_groups', {}) + if group not in node_groups or not store.is_group_active(group): + raise fastapi.HTTPException(status_code=400, detail=f'Invalid group {group}') + offset = max(offset, 0) + limit = max(limit, 1) + segments, total = store.get_segments( + doc_ids={doc_id}, + kb_id=kb_id, + group=group, + offset=offset, + limit=limit, + return_total=True, + sort_by_number=True, + ) + return { + 'items': [self._format_chunk_item(segment) for segment in segments], + 'total': total, + 'offset': offset, + 'page_size': limit, + } + + @app.get('/doc/chunks') + def list_doc_chunks( + self, + algo_id: str = '__default__', + kb_id: Optional[str] = None, + doc_id: Optional[str] = None, + group: Optional[str] = None, + offset: int = 0, + page_size: int = 20, + ): + self._lazy_init() + if self._shutdown: + raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...') + if not kb_id: + raise 
fastapi.HTTPException(status_code=400, detail='kb_id is required') + if not doc_id: + raise fastapi.HTTPException(status_code=400, detail='doc_id is required') + if not group: + raise fastapi.HTTPException(status_code=400, detail='group is required') + data = self._list_doc_chunks_data( + algo_id=algo_id, + kb_id=kb_id, + doc_id=doc_id, + group=group, + offset=offset, + limit=page_size, + ) + return BaseResponse(code=200, msg='success', data=data) + @staticmethod def _resolve_add_task_type(file_infos) -> str: has_reparse = False diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index dcdbe6240..37a6ad38f 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -26,7 +26,7 @@ class DocumentProcessorWorker(ModuleBase): class _Impl(): def __init__(self, db_config: dict = None, task_poller=None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: list[str] = None, - high_priority_only: bool = False, poll_mode: str = 'direct'): + high_priority_only: bool = False, poll_mode: str = 'thread'): self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') self._shutdown = False self._processors: dict[str, _Processor] = {} # algo_id -> _Processor @@ -254,6 +254,7 @@ def _exec_reparse_task( def _exec_transfer_task(self, processor: _Processor, task_id: str, payload: dict): try: + self._validate_transfer_payload(payload) file_infos = payload.get('file_infos') kb_id = payload.get('kb_id', None) input_files = [] @@ -340,6 +341,47 @@ def _build_task_context(task_type: str, payload: dict) -> dict: def _resolve_task_type(self, request: AddDocRequest) -> str: return _resolve_add_doc_task_type(request) + @staticmethod + def _validate_transfer_payload(payload: dict): # noqa: C901 + file_infos = payload.get('file_infos') + if not isinstance(file_infos, list) or not file_infos: + raise ValueError('file_infos is required for task_type DOC_TRANSFER') + transfer_mode = None + target_kb_id = None + target_algo_id = None + target_doc_ids = set() + request_algo_id = payload.get('algo_id') + for idx, file_info in enumerate(file_infos): + transfer_params = file_info.get('transfer_params') + if not isinstance(transfer_params, dict) or not transfer_params: + raise ValueError(f'transfer_params is required for task_type DOC_TRANSFER at index {idx}') + current_mode = transfer_params.get('mode') + current_target_kb_id = transfer_params.get('target_kb_id') + current_target_algo_id = transfer_params.get('target_algo_id') + current_target_doc_id = transfer_params.get('target_doc_id') + if current_mode not in ('cp', 'mv'): + raise ValueError('transfer_params.mode must be one of [cp, mv]') + if not current_target_kb_id: + raise ValueError('transfer_params.target_kb_id is required for task_type DOC_TRANSFER') + if not current_target_algo_id: + raise ValueError('transfer_params.target_algo_id is required for task_type DOC_TRANSFER') + if not current_target_doc_id: + raise ValueError('transfer_params.target_doc_id is required for task_type DOC_TRANSFER') + if transfer_mode is not None and transfer_mode != current_mode: + raise ValueError('transfer_params.mode must be the same for all files') + if target_kb_id is not None and target_kb_id != current_target_kb_id: + raise ValueError('transfer_params.target_kb_id must be the same for all files') + if target_algo_id is not None and target_algo_id != current_target_algo_id: + raise 
ValueError('transfer_params.target_algo_id must be the same for all files') + if request_algo_id is not None and current_target_algo_id != request_algo_id: + raise ValueError('transfer_params.target_algo_id must match request.algo_id') + if current_target_doc_id in target_doc_ids: + raise ValueError('transfer_params.target_doc_id must be unique for all files') + transfer_mode = current_mode + target_kb_id = current_target_kb_id + target_algo_id = current_target_algo_id + target_doc_ids.add(current_target_doc_id) + def _validate_task_payload(self, task_type: str, payload: dict): if not isinstance(payload, dict): raise ValueError('payload must be a dict') @@ -356,6 +398,8 @@ def _validate_task_payload(self, task_type: str, payload: dict): doc_ids = payload.get('doc_ids') if not isinstance(doc_ids, list) or not doc_ids: raise ValueError('doc_ids is required for task_type DOC_DELETE') + if task_type == TaskType.DOC_TRANSFER.value: + self._validate_transfer_payload(payload) def _summarize_task_payload(self, task_type: str, payload: dict) -> str: summary = { @@ -672,7 +716,7 @@ def shutdown(self): def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = None, task_poller=None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: list[str] = None, high_priority_only: bool = False, - poll_mode: str = 'direct'): + poll_mode: str = 'thread'): super().__init__() self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') self._num_workers = num_workers diff --git a/lazyllm/tools/rag/store/document_store.py b/lazyllm/tools/rag/store/document_store.py index 7344502df..43a1e4f82 100644 --- a/lazyllm/tools/rag/store/document_store.py +++ b/lazyllm/tools/rag/store/document_store.py @@ -218,11 +218,13 @@ def remove_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, group: Optional[str] = None, kb_id: Optional[str] = None, limit: Optional[int] = None, offset: int = 0, return_total: bool = False, - numbers: Optional[Set] = None, **kwargs) -> Union[List[DocNode], Tuple[List[DocNode], int]]: + numbers: Optional[Set] = None, sort_by_number: bool = False, + **kwargs) -> Union[List[DocNode], Tuple[List[DocNode], int]]: try: result = self.get_segments(uids=uids, doc_ids=doc_ids, group=group, kb_id=kb_id, numbers=numbers, limit=limit, - offset=offset, return_total=return_total, **kwargs) + offset=offset, return_total=return_total, + sort_by_number=sort_by_number, **kwargs) if return_total: segments, total = result return [self._deserialize_node(segment) for segment in segments], total @@ -234,7 +236,8 @@ def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = N def get_segments(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, group: Optional[str] = None, kb_id: Optional[str] = None, limit: Optional[int] = None, offset: int = 0, return_total: bool = False, - numbers: Optional[Set] = None, **kwargs) -> Union[List[dict], Tuple[List[dict], int]]: + numbers: Optional[Set] = None, sort_by_number: bool = False, + **kwargs) -> Union[List[dict], Tuple[List[dict], int]]: # get a set of segments by uids # get the segments of the whole file -- doc ids only # get the segments of a certain group for one file -- doc ids and group (kb_id is optional) @@ -251,6 +254,8 @@ def get_segments(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] LOG.warning(f'[_DocumentStore - 
{self._algo_name}] Group {group} is not active, skip') continue segments.extend(self.impl.get(self._gen_collection_name(group), criteria, **kwargs)) + if sort_by_number: + segments = self._sort_segments_by_number(segments) total = len(segments) segments = self._slice_segments(segments, limit, offset) return (segments, total) if return_total else segments @@ -288,6 +293,10 @@ def _slice_segments(segments: List[dict], limit: Optional[int], offset: int) -> return segments[offset:end] return segments + @staticmethod + def _sort_segments_by_number(segments: List[dict]) -> List[dict]: + return sorted(segments, key=lambda segment: (segment.get('number', 0), segment.get('uid', ''))) + def _resolve_groups(self, group: Optional[str]) -> List[str]: if not group: return sorted(self._activated_groups) From 318e966316a1cbdc368245d050a90c111c27dafc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 18 Mar 2026 18:46:25 +0800 Subject: [PATCH 27/46] fix opensearch query --- .../tools/rag/store/segment/elasticsearch_store.py | 12 ++++++++++++ lazyllm/tools/rag/store/segment/opensearch_store.py | 12 ++++++++++++ 2 files changed, 24 insertions(+) diff --git a/lazyllm/tools/rag/store/segment/elasticsearch_store.py b/lazyllm/tools/rag/store/segment/elasticsearch_store.py index e2678645e..50010a4b4 100644 --- a/lazyllm/tools/rag/store/segment/elasticsearch_store.py +++ b/lazyllm/tools/rag/store/segment/elasticsearch_store.py @@ -318,7 +318,19 @@ def _construct_criteria(self, criteria: Optional[dict] = None) -> dict: # noqa: vals = [vals] return {'query': {'ids': {'values': vals}}} + exact_match_fields = {'doc_id', 'kb_id', 'group', 'parent'} + def _add_clause(key, val): + if key in exact_match_fields: + clauses = [] + if isinstance(val, list): + clauses.append({'terms': {key: val}}) + clauses.append({'terms': {f'{key}.keyword': val}}) + else: + clauses.append({'term': {key: val}}) + clauses.append({'term': {f'{key}.keyword': val}}) + must_clauses.append({'bool': {'should': clauses, 'minimum_should_match': 1}}) + return if isinstance(val, list): must_clauses.append({'terms': {key: val}}) else: diff --git a/lazyllm/tools/rag/store/segment/opensearch_store.py b/lazyllm/tools/rag/store/segment/opensearch_store.py index 696169316..7aa914dc5 100644 --- a/lazyllm/tools/rag/store/segment/opensearch_store.py +++ b/lazyllm/tools/rag/store/segment/opensearch_store.py @@ -263,7 +263,19 @@ def _construct_criteria(self, criteria: Optional[dict] = None) -> dict: # noqa: vals = [vals] return {'query': {'ids': {'values': vals}}} + exact_match_fields = {'doc_id', 'kb_id', 'group', 'parent'} + def _add_clause(key, val): + if key in exact_match_fields: + clauses = [] + if isinstance(val, list): + clauses.append({'terms': {key: val}}) + clauses.append({'terms': {f'{key}.keyword': val}}) + else: + clauses.append({'term': {key: val}}) + clauses.append({'term': {f'{key}.keyword': val}}) + must_clauses.append({'bool': {'should': clauses, 'minimum_should_match': 1}}) + return if isinstance(val, list): must_clauses.append({'terms': {key: val}}) else: From 20e9373b5b829052c4577285632525acbfbc5813 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 18 Mar 2026 19:03:12 +0800 Subject: [PATCH 28/46] support real offset for segment store --- lazyllm/tools/rag/doc_impl.py | 19 +++++++++--- lazyllm/tools/rag/document.py | 20 ++++++++----- lazyllm/tools/rag/store/document_store.py | 20 +++++++++++++ .../tools/rag/store/hybrid/hybrid_store.py | 11 +++++-- 
lazyllm/tools/rag/store/hybrid/map_store.py | 30 +++++++++++++++++-- .../rag/store/segment/elasticsearch_store.py | 26 ++++++++++++++++ .../rag/store/segment/opensearch_store.py | 28 ++++++++++++++++- 7 files changed, 136 insertions(+), 18 deletions(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 0792a070f..ecd55d065 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -537,11 +537,22 @@ def _register_schema_set(self, schema_set: Type[BaseModel], kb_id: Optional[str] return set_id def _get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, - group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None - ) -> List[DocNode]: + group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None, + limit: Optional[int] = None, offset: int = 0, return_total: bool = False, + sort_by_number: bool = False) -> Union[List[DocNode], Tuple[List[DocNode], int]]: self._lazy_init() - return self._store.get_nodes(uids=uids, doc_ids=doc_ids, group=group, - kb_id=kb_id, numbers=numbers, display=True) + return self._store.get_nodes( + uids=uids, + doc_ids=doc_ids, + group=group, + kb_id=kb_id, + numbers=numbers, + limit=limit, + offset=offset, + return_total=return_total, + sort_by_number=sort_by_number, + display=True, + ) def _get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5), merge: bool = False) -> Union[List[DocNode], DocNode]: diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index f35e23086..f46d96b0b 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Optional, Dict, Union, List, Type, Set +from typing import Callable, Optional, Dict, Union, List, Type, Set, Tuple from functools import cached_property from pydantic import BaseModel import lazyllm @@ -401,9 +401,12 @@ def register_schema_set(self, schema_set: Type[BaseModel], kb_id: Optional[str] return self._forward('_register_schema_set', schema_set, kb_id, force_refresh) def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, - group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None - ) -> List[DocNode]: - return self._forward('_get_nodes', uids, doc_ids, group, kb_id, numbers) + group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None, + limit: Optional[int] = None, offset: int = 0, return_total: bool = False, + sort_by_number: bool = False) -> Union[List[DocNode], Tuple[List[DocNode], int]]: + return self._forward( + '_get_nodes', uids, doc_ids, group, kb_id, numbers, limit, offset, return_total, sort_by_number, + ) def get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5), merge: bool = False) -> Union[List[DocNode], DocNode]: @@ -434,9 +437,12 @@ def forward(self, *args, **kw): return self._forward('retrieve', *args, **kw) def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, - group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None - ) -> List[DocNode]: - return self._forward('_get_nodes', uids, doc_ids, group, kb_id, numbers) + group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None, + limit: Optional[int] = None, offset: int = 0, return_total: bool = False, + sort_by_number: bool = False) -> Union[List[DocNode], Tuple[List[DocNode], int]]: + return self._forward( + '_get_nodes', 
uids, doc_ids, group, kb_id, numbers, limit, offset, return_total, sort_by_number, + ) def get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5), merge: bool = False) -> Union[List[DocNode], DocNode]: diff --git a/lazyllm/tools/rag/store/document_store.py b/lazyllm/tools/rag/store/document_store.py index 43a1e4f82..e6460e0a6 100644 --- a/lazyllm/tools/rag/store/document_store.py +++ b/lazyllm/tools/rag/store/document_store.py @@ -248,6 +248,20 @@ def get_segments(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] limit, offset = self._normalize_pagination(limit, offset) criteria = self._build_get_criteria(uids, doc_ids, kb_id, numbers, kwargs.get('parent')) groups = self._resolve_groups(group) + if self._can_use_native_segment_pagination(groups, sort_by_number): + result = self.impl.get( + self._gen_collection_name(groups[0]), + criteria, + limit=limit, + offset=offset, + return_total=return_total, + sort_by_number=sort_by_number, + **kwargs, + ) + if return_total: + segments, total = result if isinstance(result, tuple) else (result, len(result)) + return segments, total + return result[0] if isinstance(result, tuple) else result segments = [] for group in groups: if not self.is_group_active(group): @@ -302,6 +316,12 @@ def _resolve_groups(self, group: Optional[str]) -> List[str]: return sorted(self._activated_groups) return [group] + def _can_use_native_segment_pagination(self, groups: List[str], sort_by_number: bool) -> bool: + if len(groups) != 1 or not sort_by_number: + return False + store = getattr(self.impl, 'segment_store', self.impl) + return store.__class__.__name__ in {'MapStore', 'OpenSearchStore', 'ElasticSearchStore'} + def _build_get_criteria(self, uids: Optional[List[str]], doc_ids: Optional[Set], kb_id: Optional[str], numbers: Optional[Set] = None, parent: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]: diff --git a/lazyllm/tools/rag/store/hybrid/hybrid_store.py b/lazyllm/tools/rag/store/hybrid/hybrid_store.py index 5b3c1b788..b9ecc7902 100644 --- a/lazyllm/tools/rag/store/hybrid/hybrid_store.py +++ b/lazyllm/tools/rag/store/hybrid/hybrid_store.py @@ -37,9 +37,13 @@ def delete(self, collection_name: str, criteria: Optional[dict] = None, **kwargs @override def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: res_segments = self.segment_store.get(collection_name=collection_name, criteria=criteria, **kwargs) - if not res_segments: return [] + total = None + if isinstance(res_segments, tuple): + res_segments, total = res_segments + if not res_segments: + return ([], total or 0) if total is not None else [] uids = [item.get('uid') for item in res_segments] - res_vectors = self.vector_store.get(collection_name=collection_name, criteria={'uid': uids}, **kwargs) + res_vectors = self.vector_store.get(collection_name=collection_name, criteria={'uid': uids}) data = {} for item in res_segments: @@ -50,7 +54,8 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) - else: raise ValueError(f'[HybridStore - get] uid {item["uid"]} in vector store' ' but not found in segment store') - return list(data.values()) + ordered = [data[item.get('uid')] for item in res_segments if item.get('uid') in data] + return (ordered, total) if total is not None else ordered @override def search(self, collection_name: str, query: str, query_embedding: Optional[Union[dict, List[float]]] = None, diff --git a/lazyllm/tools/rag/store/hybrid/map_store.py b/lazyllm/tools/rag/store/hybrid/map_store.py index 
6dbe98162..3a7d20ab9 100644 --- a/lazyllm/tools/rag/store/hybrid/map_store.py +++ b/lazyllm/tools/rag/store/hybrid/map_store.py @@ -179,24 +179,48 @@ def _remove_uid(uid: str, use_discard: bool) -> None: @override def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: + limit = kwargs.get('limit') + offset = max(kwargs.get('offset', 0) or 0, 0) + return_total = kwargs.get('return_total', False) + sort_by_number = kwargs.get('sort_by_number', False) if self._sqlite_first: with self._lock: conn = self._open_conn() cur = conn.cursor() self._ensure_table(cur, collection_name) where, args = self._build_where(criteria) + order_clause = ' ORDER BY number ASC, uid ASC' if sort_by_number else '' + page_clause = '' + page_args = () + if limit is not None: + page_clause = ' LIMIT ? OFFSET ?' + page_args = (limit, offset) + elif offset > 0: + page_clause = ' LIMIT -1 OFFSET ?' + page_args = (offset,) + total = None + if return_total: + cur.execute(f'SELECT COUNT(*) FROM {collection_name}{where}', args) + total = cur.fetchone()[0] cur.execute(f'''SELECT uid, doc_id, "group", content, meta, global_meta, type, number, kb_id, excluded_embed_metadata_keys, excluded_llm_metadata_keys, parent, answer, image_keys - FROM {collection_name}{where}''', args) + FROM {collection_name}{where}{order_clause}{page_clause}''', args + page_args) rows = cur.fetchall() res = [] for r in rows: item = self._deserialize_data(r) res.append(item) - return res + return (res, total) if return_total else res else: uids = self._get_uids_by_criteria(collection_name, criteria) - return [self._uid2data[uid] for uid in uids if uid in self._uid2data] + res = [self._uid2data[uid] for uid in uids if uid in self._uid2data] + if sort_by_number: + res = sorted(res, key=lambda item: (item.get('number', 0), item.get('uid', ''))) + total = len(res) + if offset > 0 or limit is not None: + end = None if limit is None else offset + limit + res = res[offset:end] + return (res, total) if return_total else res def _build_where(self, criteria: dict): if not criteria: diff --git a/lazyllm/tools/rag/store/segment/elasticsearch_store.py b/lazyllm/tools/rag/store/segment/elasticsearch_store.py index 50010a4b4..817fcc770 100644 --- a/lazyllm/tools/rag/store/segment/elasticsearch_store.py +++ b/lazyllm/tools/rag/store/segment/elasticsearch_store.py @@ -211,6 +211,10 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) - results: List[dict] = [] criteria = dict(criteria) if criteria else {} + limit = kwargs.get('limit') + offset = max(kwargs.get('offset', 0) or 0, 0) + return_total = kwargs.get('return_total', False) + sort_by_number = kwargs.get('sort_by_number', False) # Query by primary key(mget) if criteria and self._primary_key in criteria: vals = criteria.pop(self._primary_key) @@ -224,6 +228,28 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) - seg = self._transform_segment(doc) if seg: results.append(seg) + if sort_by_number: + results = sorted(results, key=lambda item: (item.get('number', 0), item.get('uid', ''))) + total = len(results) + if offset > 0 or limit is not None: + end = None if limit is None else offset + limit + results = results[offset:end] + return (results, total) if return_total else results + elif sort_by_number and (limit is not None or offset > 0 or return_total): + body = self._construct_criteria(criteria) or {'query': {'match_all': {}}} + body['sort'] = [{'number': {'order': 'asc'}}, {'_id': {'order': 'asc'}}] + if offset > 0: + 
body['from'] = offset + if limit is not None: + body['size'] = limit + elif offset > 0: + body['size'] = 10000 + if return_total: + body['track_total_hits'] = True + resp = self._client.search(index=collection_name, body=body) + results = [self._transform_segment(hit) for hit in resp['hits']['hits']] + total = resp['hits']['total']['value'] if return_total else len(results) + return (results, total) if return_total else results else: helpers = elasticsearch.helpers diff --git a/lazyllm/tools/rag/store/segment/opensearch_store.py b/lazyllm/tools/rag/store/segment/opensearch_store.py index 7aa914dc5..a210e4612 100644 --- a/lazyllm/tools/rag/store/segment/opensearch_store.py +++ b/lazyllm/tools/rag/store/segment/opensearch_store.py @@ -156,13 +156,17 @@ def delete(self, collection_name: str, criteria: Optional[dict] = None, **kwargs return False @override - def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: + def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: # noqa: C901 try: if not self._client.indices.exists(index=collection_name): LOG.warning(f'[OpenSearchStore - get] Index {collection_name} does not exist') return [] results: List[dict] = [] criteria = dict(criteria) if criteria else {} + limit = kwargs.get('limit') + offset = max(kwargs.get('offset', 0) or 0, 0) + return_total = kwargs.get('return_total', False) + sort_by_number = kwargs.get('sort_by_number', False) if criteria and self._primary_key in criteria: vals = criteria.pop(self._primary_key) if not isinstance(vals, list): @@ -172,6 +176,28 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) - for doc in resp['docs']: if doc.get('found', False): results.append(self._transform_segment(doc)) + if sort_by_number: + results = sorted(results, key=lambda item: (item.get('number', 0), item.get('uid', ''))) + total = len(results) + if offset > 0 or limit is not None: + end = None if limit is None else offset + limit + results = results[offset:end] + return (results, total) if return_total else results + elif sort_by_number and (limit is not None or offset > 0 or return_total): + body = self._construct_criteria(criteria) or {'query': {'match_all': {}}} + body['sort'] = [{'number': {'order': 'asc'}}, {'_id': {'order': 'asc'}}] + if offset > 0: + body['from'] = offset + if limit is not None: + body['size'] = limit + elif offset > 0: + body['size'] = 10000 + if return_total: + body['track_total_hits'] = True + resp = self._client.search(index=collection_name, body=body) + results = [self._transform_segment(hit) for hit in resp['hits']['hits']] + total = resp['hits']['total']['value'] if return_total else len(results) + return (results, total) if return_total else results else: spec = importlib.util.find_spec('opensearchpy.helpers') helpers = importlib.util.module_from_spec(spec) From 786ecbe4bd6ff0b4a803502f7b4a358e9b36ba6e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 23 Mar 2026 18:24:12 +0800 Subject: [PATCH 29/46] refactor: expose public doc and processor APIs --- lazyllm/tools/rag/doc_service/doc_server.py | 3 +++ lazyllm/tools/rag/parsing_service/server.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py index 1308a8f60..ce928ebe1 100644 --- a/lazyllm/tools/rag/doc_service/doc_server.py +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -738,6 +738,9 @@ def get_task_info(self, task_id: 
str): def get_task(self, task_id: str): return self._dispatch('get_task', task_id) + def set_runtime_callback_url(self, callback_url: str): + return self._dispatch('set_runtime_callback_url', callback_url) + def cancel_task(self, task_id: str): return self._dispatch('cancel_task_by_id', task_id) diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index 674aefb3b..d8ee1f307 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -878,6 +878,11 @@ def set_callback_url(self, callback_url: Optional[str]): raise RuntimeError('set_callback_url is only supported in local server mode') return self._dispatch('set_callback_url', callback_url) + @property + def url(self): + impl = self._impl + return impl._url if isinstance(impl, ServerModule) else impl.url + def _dispatch(self, method: str, *args, **kwargs): try: impl = self._impl From 3fb0f51fd755bd0ead22dcebb958798500443fb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 24 Mar 2026 19:03:44 +0800 Subject: [PATCH 30/46] fix file transfer --- lazyllm/docs/tools/tool_rag.py | 36 +- lazyllm/tools/rag/__init__.py | 3 - lazyllm/tools/rag/doc_impl.py | 181 +++--- lazyllm/tools/rag/doc_manager.py | 3 +- lazyllm/tools/rag/doc_service/base.py | 4 + lazyllm/tools/rag/doc_service/doc_manager.py | 118 +++- lazyllm/tools/rag/doc_to_db/extractor.py | 8 +- lazyllm/tools/rag/document.py | 49 +- lazyllm/tools/rag/utils.py | 3 +- tests/basic_tests/RAG/test_doc_manager.py | 294 --------- .../RAG/test_doc_service_doc_manager.py | 559 ++++++++++++++++++ .../RAG/test_doc_service_doc_server.py | 248 ++++++++ .../basic_tests/RAG/test_doc_service_mock.py | 315 +--------- tests/basic_tests/RAG/test_document.py | 190 +++--- 14 files changed, 1130 insertions(+), 881 deletions(-) delete mode 100644 tests/basic_tests/RAG/test_doc_manager.py create mode 100644 tests/basic_tests/RAG/test_doc_service_doc_manager.py create mode 100644 tests/basic_tests/RAG/test_doc_service_doc_server.py diff --git a/lazyllm/docs/tools/tool_rag.py b/lazyllm/docs/tools/tool_rag.py index bd022c1b7..69b43bb5d 100644 --- a/lazyllm/docs/tools/tool_rag.py +++ b/lazyllm/docs/tools/tool_rag.py @@ -1419,7 +1419,7 @@ 将算法/知识库绑定到指定 schema 集合;若提供 schema_set 会先注册;可选 force_refresh 覆盖已有绑定并清理旧数据。 Args: - algo_id (str, optional): 算法/Document 名称,默认 DocListManager.DEFAULT_GROUP_NAME。 + algo_id (str, optional): 算法/Document 名称,默认 `__default__`。 kb_id (str, optional): 知识库 ID,默认 DEFAULT_KB_ID。 schema_set_id (str, optional): 已有 schema 集合 ID。 schema_set (Type[BaseModel], optional): 新 schema,传入则会注册后绑定。 @@ -1433,7 +1433,7 @@ Bind an algo/kb pair to a schema set; optionally register a provided schema_set first; with force_refresh you can override an existing binding and purge old records. Args: - algo_id (str, optional): Algorithm/Document name, defaults to DocListManager.DEFAULT_GROUP_NAME. + algo_id (str, optional): Algorithm/Document name, defaults to `__default__`. kb_id (str, optional): Knowledge base id, defaults to DEFAULT_KB_ID. schema_set_id (str, optional): Existing schema set id to bind. schema_set (Type[BaseModel], optional): Schema to register and bind if no id is provided. 
@@ -1470,7 +1470,7 @@ Args: data (Union[str, List[DocNode]]): 文本或 DocNode 列表(需同一文档)。 - algo_id (str, optional): 算法/Document 名称,默认 DocListManager.DEFAULT_GROUP_NAME。 + algo_id (str, optional): 算法/Document 名称,默认 `__default__`。 schema_set_id (str, optional): 指定使用的 schema 集合 ID。 schema_set (Type[BaseModel], optional): 动态注册并使用的 schema。 @@ -1483,7 +1483,7 @@ Args: data (Union[str, List[DocNode]]): Text or list of DocNodes from a single document. - algo_id (str, optional): Algorithm/Document name, defaults to DocListManager.DEFAULT_GROUP_NAME. + algo_id (str, optional): Algorithm/Document name, defaults to `__default__`. schema_set_id (str, optional): Schema set id to use. schema_set (Type[BaseModel], optional): Schema to register and use if no id is provided. @@ -5591,10 +5591,10 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: # rag/doc_manager.py add_chinese_doc('rag.DocManager', """ -DocManager类管理文档列表及相关操作,并通过API提供文档上传、删除、分组等功能。 +已废弃。请改用 `lazyllm.tools.rag.doc_service.DocServer`。 Args: - dlm (DocListManager): 文档列表管理器,用于处理具体的文档操作。 + dlm (DocListManager): 旧文档列表管理器。 """) @@ -5756,10 +5756,10 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: """) add_english_doc('rag.DocManager', """ -The `DocManager` class manages document lists and related operations, providing APIs for uploading, deleting, and grouping documents. +Deprecated. Use `lazyllm.tools.rag.doc_service.DocServer` instead. Args: - dlm (DocListManager): Document list manager responsible for handling document-related operations. + dlm (DocListManager): Legacy document list manager. """) add_english_doc('rag.DocManager.document', """ @@ -6021,7 +6021,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: # rag/utils.py add_chinese_doc('rag.utils.DocListManager', """\ -抽象基类,用于管理文档列表和监控文档目录变化。 +已废弃。请改用 `Document(dataset_path=..., enable_path_monitoring=...)` 或 `DocServer`。 Args: path:要监控的文档目录路径。 @@ -6197,7 +6197,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: 说明: - 方法首先通过辅助函数 `_add_doc_records` 创建文档记录。 - - 文件添加后,会自动关联到默认的知识库组 (`DocListManager.DEFAULT_GROUP_NAME`)。 + - 文件添加后,会自动关联到默认的知识库组(`__default__`)。 - 批量处理确保在添加大量文件时具有良好的可扩展性。 ''') @@ -6315,7 +6315,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: ''') add_english_doc('rag.utils.DocListManager', """\ -Abstract base class for managing document lists and monitoring changes in a document directory. +Deprecated. Use `Document(dataset_path=..., enable_path_monitoring=...)` or `DocServer`. Args: path: Path of the document directory to monitor. @@ -6469,7 +6469,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: If `details=True`, returns a list of detailed rows with additional metadata. Notes: - The method first creates document records using the `_add_doc_records` helper function. - - After the files are added, they are automatically linked to the default KB group (`DocListManager.DEFAULT_GROUP_NAME`). + - After the files are added, they are automatically linked to the default KB group (`__default__`). - Batch processing ensures scalability when adding a large number of files. ''') @@ -6488,7 +6488,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: Notes: - The method first creates document records using the helper function _add_doc_records. -- After the files are added, they are automatically linked to the default knowledge base group (DocListManager.DEFAULT_GROUP_NAME). 
+- After the files are added, they are automatically linked to the default knowledge base group (`__default__`).
 - Batch processing ensures good scalability when adding a large number of files.
 
@@ -6608,15 +6608,14 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]:
 
 add_example('rag.utils.DocListManager', '''
 >>> import lazyllm
->>> from lazyllm.rag.utils import DocListManager
->>> manager = DocListManager(path='your_file_path/', name="test_manager", enable_path_monitoring=False)
+>>> # Deprecated. Use Document(dataset_path='your_file_path', enable_path_monitoring=True)
 >>> added_docs = manager.add_files([test_file_list])
 >>> manager.enable_path_monitoring(True)
 >>> deleted = manager.delete_files([delete_file_list])
 ''')
 
 add_chinese_doc('rag.utils.SqliteDocListManager', '''\
-基于 SQLite 的文档管理器,用于本地文件的持久化存储、状态管理与元信息追踪。
+已废弃。请改用 `Document(dataset_path=...)` 或新的 `doc_service` 数据库。
 
 该类继承自 DocListManager,利用 SQLite 数据库存储文档记录。适用于管理具有唯一标识符的本地文档资源,并提供便捷的插入、查询、更新与状态过滤接口,支持可选的路径监控功能。
 
@@ -6627,7 +6626,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]:
 ''')
 
 add_english_doc('rag.utils.SqliteDocListManager', '''\
-SQLite-based document manager for persistent local file storage, status tracking, and metadata management.
+Deprecated. Use `Document(dataset_path=...)` or the new `doc_service` database instead.
 
 This class inherits from DocListManager and uses a SQLite backend to store document records. It is suitable for managing locally identified documents with support for inserting, querying, updating, and filtering based on status. Optional file path monitoring is also supported.
 
@@ -6638,8 +6637,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]:
 ''')
 
 add_example('rag.utils.SqliteDocListManager', '''\
->>> from lazyllm.tools.rag.utils import SqliteDocListManager
->>> manager = SqliteDocListManager(path="./data", name="docs.sqlite")
+>>> # Deprecated. Use Document(dataset_path=...) or DocServer instead.
 >>> manager.insert({"uid": "doc_001", "name": "example.txt", "status": "ready"})
 >>> print(manager.get("doc_001"))
 >>> files = manager.list_files(limit=5, details=True)
 
diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py
index 86834668d..f80c57eee 100644
--- a/lazyllm/tools/rag/__init__.py
+++ b/lazyllm/tools/rag/__init__.py
@@ -17,7 +17,6 @@
                      MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader,
                      MineruPDFReader)
 from .dataReader import SimpleDirectoryReader, FileReader
-from .doc_manager import DocManager, DocListManager
 from .global_metadata import GlobalMetadataDesc as DocField
 from .data_type import DataType
 from .index_base import IndexBase
@@ -62,8 +61,6 @@
     'VideoAudioReader',
     'SimpleDirectoryReader',
     'MineruPDFReader',
-    'DocManager',
-    'DocListManager',
     'DocField',
     'DataType',
     'IndexBase',
diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py
index ecd55d065..8cf651832 100644
--- a/lazyllm/tools/rag/doc_impl.py
+++ b/lazyllm/tools/rag/doc_impl.py
@@ -1,4 +1,4 @@
-import json
+import os
 import threading
 import time
 from enum import Enum
@@ -14,14 +14,13 @@
 from .store.document_store import _DocumentStore
 from .doc_node import DocNode
 from .data_loaders import DirectoryReader
-from .utils import DocListManager, is_sparse, _get_default_db_config
-from .global_metadata import GlobalMetadataDesc, RAG_DOC_ID, RAG_KB_ID
+from .utils import gen_docid, is_sparse, _get_default_db_config
+from .global_metadata import GlobalMetadataDesc, RAG_DOC_ID, RAG_DOC_PATH, RAG_KB_ID
 from .data_type import DataType
 from .parsing_service import _Processor, DocumentProcessor
 from .embed_wrapper import _EmbedWrapper
 from .doc_to_db import SchemaExtractor
 from dataclasses import dataclass
-from itertools import repeat
 
 _transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser)
 
@@ -68,9 +67,11 @@ class DocImpl:
     _builtin_node_groups: Dict[str, Dict] = {}
     _global_node_groups: Dict[str, Dict] = {}
     _registered_file_reader: Dict[str, Callable] = {}
+    DEFAULT_GROUP_NAME = '__default__'
 
-    def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None,
-                 doc_files: Optional[str] = None, kb_group_name: Optional[str] = None,
+    def __init__(self, embed: Dict[str, Callable], dataset_path: Optional[str] = None,
+                 enable_path_monitoring: bool = False,
+                 doc_files: Optional[List[str]] = None, kb_group_name: Optional[str] = None,
                  global_metadata_desc: Dict[str, GlobalMetadataDesc] = None,
                  store: Optional[Union[Dict, LazyLLMStoreBase]] = None,
                  processor: Optional[DocumentProcessor] = None, algo_name: Optional[str] = None,
@@ -78,9 +79,15 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N
                  schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None):
         super().__init__()
         self._local_file_reader: Dict[str, Callable] = {}
-        self._kb_group_name = kb_group_name or DocListManager.DEFAULT_GROUP_NAME
-        self._dlm, self._doc_files = dlm, doc_files
+        self._kb_group_name = kb_group_name or self.DEFAULT_GROUP_NAME
+        self._dataset_path = dataset_path
+        self._enable_path_monitoring = bool(enable_path_monitoring and dataset_path and doc_files is None)
+        self._doc_files = doc_files
         self._reader = DirectoryReader(None, self._local_file_reader, DocImpl._registered_file_reader)
+        self._local_monitor_thread: Optional[threading.Thread] = None
+        self._local_monitor_continue = False
+        self._local_monitor_interval = 10
+        self._local_monitor_lock = threading.Lock()
         self.node_groups: 
Dict[str, Dict] = { LAZY_ROOT_NAME: dict(parent=None, display_name='Original Source', group_type=NodeGroupType.ORIGINAL), LAZY_IMAGE_GROUP: dict(parent=None, display_name='Image Node', group_type=NodeGroupType.OTHER) @@ -173,7 +180,7 @@ def _lazy_init(self) -> None: self._init_node_groups() self._create_schema_extractor() self._create_store() - cloud = not (self._dlm or self._doc_files is not None) + cloud = not (self._dataset_path or self._doc_files is not None) self._resolve_index_pending_registrations() if self._processor: @@ -185,16 +192,10 @@ def _lazy_init(self) -> None: self._schema_extractor, self._display_name, self._description) # init files when `cloud` is False - if not cloud and self._store.is_group_empty(LAZY_ROOT_NAME): - ids, pathes, metadatas = self._list_files(upload_status=DocListManager.Status.success) - self._processor.add_doc(pathes, ids, metadatas) - if pathes and self._dlm: - self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, - new_status=DocListManager.Status.success) - if self._dlm: - self._daemon = threading.Thread(target=self.worker) - self._daemon.daemon = True - self._daemon.start() + if not cloud: + self._sync_local_dataset() + if self._dataset_path and self._enable_path_monitoring: + self._start_local_monitoring() def _resolve_index_pending_registrations(self): for index_type, index_cls, index_args, index_kwargs in self._index_pending_registrations: @@ -297,80 +298,82 @@ def add_reader(self, pattern: str, func: Optional[Callable] = None): self._reader._lazy_init.flag.reset() def _add_doc_to_store(self, input_files: List[str], ids: List[str], metadatas: List[Dict[str, Any]]): - success_ids, failed_ids = [], [] - for filepath, doc_id, metadata in zip(input_files, ids or repeat(None), metadatas or repeat(None)): + for filepath, doc_id, metadata in zip(input_files, ids, metadatas): try: self._processor.add_doc([filepath], [doc_id], [metadata] if metadata is not None else None) - success_ids.append(doc_id) except Exception as e: LOG.error(f'Error adding document {doc_id} ({filepath}) to store: {e}') - failed_ids.append(doc_id) - - if success_ids: - self._dlm.update_kb_group(cond_file_ids=success_ids, cond_group=self._kb_group_name, - new_status=DocListManager.Status.success) - if failed_ids: - self._dlm.update_kb_group(cond_file_ids=failed_ids, cond_group=self._kb_group_name, - new_status=DocListManager.Status.failed) - - def _batch_call(self, func: Callable, *args, batch_size: int = 10, **kwargs): - batch_count = next((len(arg) for arg in args if isinstance(arg, (tuple, list))), 0) - for i in range(0, batch_count, batch_size): - func(*[arg[i:i + batch_size] if isinstance(arg, (list, tuple)) else arg for arg in args], **kwargs) - - def worker(self): - while True: - # Apply meta changes - rows = self._dlm.fetch_docs_changed_meta(self._kb_group_name) - for row in rows: - new_meta_dict = json.loads(row[1]) if row[1] else {} - self._processor.update_doc_meta(doc_id=row[0], metadata=new_meta_dict) - - # Step 1: do doc-parsing, highest priority - docs = self._dlm.get_docs_need_reparse(group=self._kb_group_name) - if docs: - filepaths = [doc.path for doc in docs] - ids = [doc.doc_id for doc in docs] - metadatas = [json.loads(doc.meta) if doc.meta else None for doc in docs] - # update status and need_reparse - self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, - new_status=DocListManager.Status.working, new_need_reparse=False) - self._delete_doc_from_store(doc_ids=ids) - self._batch_call(self._add_doc_to_store, 
filepaths, ids, metadatas, batch_size=10) - - # Step 2: After doc is deleted from related kb_group, delete doc from db - if self._kb_group_name == DocListManager.DEFAULT_GROUP_NAME: - self._dlm.delete_unreferenced_doc() - - # Step 3: do doc-deleting - ids, files, metadatas = self._list_files(status=DocListManager.Status.deleting) - if files: - self._delete_doc_from_store(doc_ids=ids) - self._dlm.delete_files_from_kb_group(ids, self._kb_group_name) - - # Step 4: do doc-adding - ids, files, metadatas = self._list_files(status=DocListManager.Status.waiting, - upload_status=DocListManager.Status.success) - if files: - self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, - new_status=DocListManager.Status.working) - self._batch_call(self._add_doc_to_store, files, ids, metadatas) - - time.sleep(10) - - def _list_files( - self, status: Union[str, List[str]] = DocListManager.Status.all, - upload_status: Union[str, List[str]] = DocListManager.Status.all - ) -> Tuple[List[str], List[str], List[Dict]]: - if self._doc_files is not None: return None, self._doc_files, None - if not self._dlm: return [], [], [] - ids, paths, metadatas = [], [], [] - for row in self._dlm.list_kb_group_files(group=self._kb_group_name, status=status, - upload_status=upload_status, details=True): - ids.append(row[0]) - paths.append(row[1]) - metadatas.append(json.loads(row[3]) if row[3] else {}) - return ids, paths, metadatas + + def _list_dataset_files(self) -> List[str]: + if not self._dataset_path or not os.path.exists(self._dataset_path): + return [] + if not os.path.isdir(self._dataset_path): + filename = os.path.basename(self._dataset_path) + if filename.startswith('.'): + return [] + return [self._dataset_path] if os.path.isfile(self._dataset_path) else [] + + files = [] + for root, dirs, names in os.walk(os.path.abspath(self._dataset_path)): + path_parts = root.split(os.sep) + if any(part.startswith('.') for part in path_parts if part): + continue + dirs[:] = [name for name in dirs if not name.startswith('.')] + files.extend(os.path.join(root, name) for name in names if not name.startswith('.')) + return sorted(files) + + def _list_local_files(self) -> Tuple[List[str], List[str], List[Dict[str, Any]]]: + paths = list(self._doc_files) if self._doc_files is not None else self._list_dataset_files() + ids = [gen_docid(path) for path in paths] + return ids, paths, [{} for _ in paths] + + def _list_store_root_docs(self) -> Dict[str, str]: + docs = {} + for node in self._store.get_nodes(group=LAZY_ROOT_NAME): + doc_id = node.global_metadata.get(RAG_DOC_ID) + path = node.global_metadata.get(RAG_DOC_PATH) + if doc_id and path: + docs.setdefault(path, doc_id) + return docs + + def _sync_local_dataset(self): + with self._local_monitor_lock: + ids, paths, metadatas = self._list_local_files() + current_docs = dict(zip(paths, ids)) + store_docs = self._list_store_root_docs() + + stale_doc_ids = [doc_id for path, doc_id in store_docs.items() if path not in current_docs] + if stale_doc_ids: + self._delete_doc_from_store(doc_ids=stale_doc_ids) + + pending_paths = [path for path in paths if path not in store_docs] + if pending_paths: + doc_id_map = dict(zip(paths, ids)) + metadata_map = dict(zip(paths, metadatas)) + self._add_doc_to_store( + pending_paths, + [doc_id_map[path] for path in pending_paths], + [metadata_map[path] for path in pending_paths], + ) + self._local_files = set(self._list_store_root_docs().keys()) + + def _monitor_local_dataset_worker(self): + while self._local_monitor_continue: + try: + 
current_paths = set(self._list_dataset_files()) + if current_paths != self._local_files: + self._sync_local_dataset() + except Exception as exc: # pragma: no cover - defensive + LOG.error(f'Failed to sync local dataset `{self._dataset_path}`: {exc}') + time.sleep(self._local_monitor_interval) + + def _start_local_monitoring(self): + if self._local_monitor_thread and self._local_monitor_thread.is_alive(): + return + self._local_monitor_continue = True + self._local_monitor_thread = threading.Thread(target=self._monitor_local_dataset_worker) + self._local_monitor_thread.daemon = True + self._local_monitor_thread.start() def _delete_doc_from_store(self, doc_ids: List[str] = None) -> None: self._processor.delete_doc(doc_ids=doc_ids) diff --git a/lazyllm/tools/rag/doc_manager.py b/lazyllm/tools/rag/doc_manager.py index 7ef5d8157..214b725b9 100644 --- a/lazyllm/tools/rag/doc_manager.py +++ b/lazyllm/tools/rag/doc_manager.py @@ -7,13 +7,14 @@ from starlette.responses import RedirectResponse import lazyllm -from lazyllm import LOG, FastapiApp as app +from lazyllm import LOG, FastapiApp as app, deprecated from lazyllm.thirdparty import fastapi from .utils import DocListManager, BaseResponse, gen_docid from .global_metadata import RAG_DOC_ID, RAG_DOC_PATH import uuid +@deprecated('lazyllm.tools.rag.doc_service.DocServer') class DocManager(lazyllm.ModuleBase): def __init__(self, dlm: DocListManager) -> None: super().__init__() diff --git a/lazyllm/tools/rag/doc_service/base.py b/lazyllm/tools/rag/doc_service/base.py index 6f6e1e13f..61622bbc5 100644 --- a/lazyllm/tools/rag/doc_service/base.py +++ b/lazyllm/tools/rag/doc_service/base.py @@ -147,10 +147,14 @@ def validate_doc_ids(self): class TransferItem(BaseModel): doc_id: str + target_doc_id: str source_kb_id: str = '__default__' source_algo_id: str = '__default__' target_kb_id: str target_algo_id: str + target_metadata: Optional[Dict[str, Any]] = None + target_filename: Optional[str] = None + target_file_path: Optional[str] = None mode: str = 'copy' diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index 07cef5ad4..e522de9a1 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -236,7 +236,8 @@ def set_callback_url(self, callback_url: str): def _ensure_indexes(self): stmts = [ - 'CREATE UNIQUE INDEX IF NOT EXISTS uq_docs_path ON lazyllm_documents(path)', + 'DROP INDEX IF EXISTS uq_docs_path', + 'CREATE INDEX IF NOT EXISTS idx_docs_path ON lazyllm_documents(path)', 'CREATE INDEX IF NOT EXISTS idx_documents_upload_status ON lazyllm_documents(upload_status)', 'CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON lazyllm_documents(updated_at)', 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_display_name ' @@ -876,13 +877,14 @@ def _call_parser_client(self, method, *args, **kwargs): def _create_parser_task(self, task_id: str, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, parser_kb_id: Optional[str] = None, - transfer_params: Optional[Dict[str, Any]] = None): + transfer_params: Optional[Dict[str, Any]] = None, + parser_doc_id: Optional[str] = None): if task_type in (TaskType.DOC_ADD, TaskType.DOC_TRANSFER): if not file_path: raise RuntimeError(f'file_path is required for task_type {task_type.value}') task_resp = self._call_parser_client( self._parser_client.add_doc, - task_id, algo_id, parser_kb_id or kb_id, doc_id, 
file_path, metadata, + task_id, algo_id, parser_kb_id or kb_id, parser_doc_id or doc_id, file_path, metadata, callback_url=self._callback_url, transfer_params=transfer_params, ) elif task_type == TaskType.DOC_REPARSE: @@ -915,7 +917,7 @@ def _enqueue_task( file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, cleanup_policy: Optional[str] = None, parser_kb_id: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None, - extra_message: Optional[Dict[str, Any]] = None, + extra_message: Optional[Dict[str, Any]] = None, parser_doc_id: Optional[str] = None, ): task_id = str(uuid4()) task_message = { @@ -955,7 +957,7 @@ def _enqueue_task( self._create_parser_task( task_id, doc_id, kb_id, algo_id, task_type, file_path=file_path, metadata=metadata, reparse_group=reparse_group, - parser_kb_id=parser_kb_id, transfer_params=transfer_params, + parser_kb_id=parser_kb_id, transfer_params=transfer_params, parser_doc_id=parser_doc_id, ) except Exception as exc: finished_at = now_ts() @@ -990,6 +992,9 @@ def _apply_doc_upload_status(self, doc_id: str, task_type: TaskType, status: Doc if task_type == TaskType.DOC_ADD: self._set_doc_upload_status(doc_id, status) return + if task_type == TaskType.DOC_TRANSFER: + self._set_doc_upload_status(doc_id, status) + return if task_type == TaskType.DOC_DELETE: if status == DocStatus.DELETING: if self._doc_relation_count(doc_id) <= 1: @@ -1083,7 +1088,7 @@ def _prepare_metadata_patch_items(self, request: MetadataPatchRequest) -> List[D prepared_items.append({'doc_id': item.doc_id, 'metadata': merged, 'file_path': doc.get('path')}) return prepared_items - def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, Any]]: + def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, Any]]: # noqa: C901 prepared_items = [] seen_pairs = set() seen_targets = set() @@ -1092,21 +1097,33 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An raise DocServiceError( 'E_INVALID_PARAM', f'invalid transfer mode: {item.mode}', {'mode': item.mode} ) - item_key = (item.doc_id, item.source_kb_id, item.target_kb_id) + if item.target_doc_id == item.doc_id: + raise DocServiceError( + 'E_INVALID_PARAM', + 'target_doc_id must be different from source doc_id', + {'doc_id': item.doc_id, 'target_doc_id': item.target_doc_id}, + ) + item_key = (item.doc_id, item.source_kb_id, item.target_kb_id, item.target_doc_id) if item_key in seen_pairs: raise DocServiceError( 'E_INVALID_PARAM', 'duplicate transfer item detected', - {'doc_id': item.doc_id, 'source_kb_id': item.source_kb_id, 'target_kb_id': item.target_kb_id}, + { + 'doc_id': item.doc_id, + 'source_kb_id': item.source_kb_id, + 'target_kb_id': item.target_kb_id, + 'target_doc_id': item.target_doc_id, + }, ) seen_pairs.add(item_key) - target_key = (item.doc_id, item.target_kb_id, item.target_algo_id) + target_key = (item.target_doc_id, item.target_kb_id, item.target_algo_id) if target_key in seen_targets: raise DocServiceError( 'E_INVALID_PARAM', 'duplicate transfer target detected', { 'doc_id': item.doc_id, + 'target_doc_id': item.target_doc_id, 'target_kb_id': item.target_kb_id, 'target_algo_id': item.target_algo_id, }, @@ -1128,11 +1145,14 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An self._validate_kb_algorithm(item.source_kb_id, item.source_algo_id) self._validate_kb_algorithm(item.target_kb_id, item.target_algo_id) self._assert_action_allowed(item.doc_id, 
item.source_kb_id, item.source_algo_id, 'transfer') - if self._has_kb_document(item.target_kb_id, item.doc_id): + if ( + self._get_doc(item.target_doc_id) is not None + or self._has_kb_document(item.target_kb_id, item.target_doc_id) + ): raise DocServiceError( 'E_STATE_CONFLICT', - f'doc already exists in target kb: {item.doc_id}', - {'doc_id': item.doc_id, 'target_kb_id': item.target_kb_id}, + f'doc already exists in target kb: {item.target_doc_id}', + {'target_doc_id': item.target_doc_id, 'target_kb_id': item.target_kb_id}, ) source_snapshot = self._get_parse_snapshot(item.doc_id, item.source_kb_id, item.source_algo_id) if source_snapshot is None or source_snapshot.get('status') != DocStatus.SUCCESS.value: @@ -1146,19 +1166,53 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An 'status': source_snapshot.get('status') if source_snapshot else None, }, ) + source_metadata = _from_json(doc.get('meta')) + target_metadata = dict(source_metadata) + if item.target_metadata: + target_metadata.update(item.target_metadata) + source_path = doc.get('path') + target_path = item.target_file_path + target_filename = item.target_filename + if target_path and target_filename: + resolved_name = os.path.basename(target_path) + if resolved_name and resolved_name != target_filename: + raise DocServiceError( + 'E_INVALID_PARAM', + 'target_filename must match the basename of target_file_path', + { + 'target_doc_id': item.target_doc_id, + 'target_filename': target_filename, + 'target_file_path': target_path, + }, + ) + if target_path and not target_filename: + target_filename = os.path.basename(target_path) or doc.get('filename') + if target_filename and not target_path: + if source_path: + target_path = os.path.join(os.path.dirname(source_path), target_filename) + else: + target_path = target_filename + if not target_filename: + target_filename = doc.get('filename') + if not target_path: + target_path = source_path prepared_items.append({ 'doc_id': item.doc_id, + 'target_doc_id': item.target_doc_id, 'source_kb_id': item.source_kb_id, 'source_algo_id': item.source_algo_id, 'target_kb_id': item.target_kb_id, 'target_algo_id': item.target_algo_id, 'mode': item.mode, - 'file_path': doc.get('path'), - 'metadata': _from_json(doc.get('meta')), + 'filename': target_filename, + 'source_type': SourceType(doc.get('source_type')), + 'source_file_path': source_path, + 'target_file_path': target_path, + 'metadata': target_metadata, 'transfer_params': { 'mode': 'mv' if item.mode == 'move' else 'cp', 'target_algo_id': item.target_algo_id, - 'target_doc_id': item.doc_id, + 'target_doc_id': item.target_doc_id, 'target_kb_id': item.target_kb_id, }, }) @@ -1305,15 +1359,26 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: for item in prepared_items: task_id = None try: - self._ensure_kb_document(item['target_kb_id'], item['doc_id']) + self._upsert_doc( + doc_id=item['target_doc_id'], + filename=item['filename'], + path=item['target_file_path'], + metadata=item['metadata'], + source_type=item['source_type'], + upload_status=DocStatus.WAITING, + ) + self._ensure_kb_document(item['target_kb_id'], item['target_doc_id']) task_id, snapshot = self._enqueue_task( - item['doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, + item['target_doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, idempotency_key=request.idempotency_key, - file_path=item['file_path'], + file_path=item['source_file_path'], metadata=item['metadata'], 
parser_kb_id=item['source_kb_id'], transfer_params=item['transfer_params'], + parser_doc_id=item['doc_id'], extra_message={ + 'source_doc_id': item['doc_id'], + 'target_doc_id': item['target_doc_id'], 'source_kb_id': item['source_kb_id'], 'source_algo_id': item['source_algo_id'], 'target_kb_id': item['target_kb_id'], @@ -1325,7 +1390,9 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: error_msg = None accepted = True except Exception as exc: - snapshot = self._get_parse_snapshot(item['doc_id'], item['target_kb_id'], item['target_algo_id']) or {} + snapshot = self._get_parse_snapshot( + item['target_doc_id'], item['target_kb_id'], item['target_algo_id'] + ) or {} task_id = task_id or snapshot.get('current_task_id') error_code = snapshot.get('last_error_code') if not error_code: @@ -1334,12 +1401,14 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: accepted = False items.append({ 'doc_id': item['doc_id'], + 'target_doc_id': item['target_doc_id'], 'task_id': task_id, 'source_kb_id': item['source_kb_id'], 'target_kb_id': item['target_kb_id'], 'source_algo_id': item['source_algo_id'], 'target_algo_id': item['target_algo_id'], 'mode': item['mode'], + 'target_file_path': item['target_file_path'], 'status': snapshot.get('status', DocStatus.FAILED.value), 'accepted': accepted, 'error_code': error_code, @@ -1463,6 +1532,8 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 ) if task_type == TaskType.DOC_ADD: self._apply_doc_upload_status(doc_id, task_type, DocStatus.WORKING) + elif task_type == TaskType.DOC_TRANSFER: + self._apply_doc_upload_status(doc_id, task_type, DocStatus.WORKING) elif task_type == TaskType.DOC_DELETE: self._apply_doc_upload_status(doc_id, task_type, DocStatus.DELETING) return {'ack': True, 'deduped': False, 'ignored_reason': None} @@ -1482,10 +1553,11 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 and task_message.get('mode') == 'move' ): source_kb_id = task_message.get('source_kb_id') - if source_kb_id and source_kb_id != kb_id: - self._remove_kb_document(source_kb_id, doc_id) - self._delete_parse_snapshots(doc_id, source_kb_id) - self._sync_doc_upload_status(doc_id) + source_doc_id = task_message.get('source_doc_id') + if source_kb_id and source_doc_id and source_kb_id != kb_id: + self._remove_kb_document(source_kb_id, source_doc_id) + self._delete_parse_snapshots(source_doc_id, source_kb_id) + self._sync_doc_upload_status(source_doc_id) self._update_task_record( callback.task_id, diff --git a/lazyllm/tools/rag/doc_to_db/extractor.py b/lazyllm/tools/rag/doc_to_db/extractor.py index 8f672d486..c39d3ffab 100644 --- a/lazyllm/tools/rag/doc_to_db/extractor.py +++ b/lazyllm/tools/rag/doc_to_db/extractor.py @@ -18,7 +18,7 @@ from ...sql.sql_manager import DBStatus, SqlManager from ..doc_node import DocNode from ..global_metadata import RAG_DOC_ID, RAG_KB_ID -from ..utils import DocListManager, _orm_to_dict +from ..utils import _orm_to_dict from ..store.store_base import DEFAULT_KB_ID from .model import ( TABLE_SCHEMA_SET_INFO, Table_ALGO_KB_SCHEMA, ExtractionMode, @@ -293,7 +293,7 @@ def has_schema_set(self, schema_set_id: str) -> bool: return True return schema_set_id in self._schema_registry - def register_schema_set_to_kb(self, algo_id: Optional[str] = DocListManager.DEFAULT_GROUP_NAME, + def register_schema_set_to_kb(self, algo_id: Optional[str] = DEFAULT_KB_ID, kb_id: Optional[str] = DEFAULT_KB_ID, schema_set_id: Optional[str] = None, schema_set: Type[BaseModel] = None, 
force_refresh: bool = False) -> str: ''' @@ -555,7 +555,7 @@ def _validate_extract_params(self, data: Union[str, List[DocNode]], algo_id: str return kb_id, doc_id, _orm_to_dict(bound) def extract_and_store(self, data: Union[str, List[DocNode]], # noqa: C901 - algo_id: str = DocListManager.DEFAULT_GROUP_NAME, + algo_id: str = DEFAULT_KB_ID, schema_set_id: str = None, schema_set: Type[BaseModel] = None) -> ExtractResult: '''Persist extracted fields for a document''' self._lazy_init() @@ -725,7 +725,7 @@ def _get_extract_data(self, algo_id: str, doc_ids: List[str], # noqa: C901 return results def forward(self, data: Union[str, List[DocNode]], - algo_id: str = DocListManager.DEFAULT_GROUP_NAME) -> ExtractResult: + algo_id: str = DEFAULT_KB_ID) -> ExtractResult: # NOTE: data should be from single file source (kb_id, doc_id should be the same) self._lazy_init() res = self.extract_and_store(data=data, algo_id=algo_id) diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index f46d96b0b..8f292d1fc 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -16,7 +16,7 @@ from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY from .store.store_base import DEFAULT_KB_ID from .index_base import IndexBase -from .utils import DocListManager, ensure_call_endpoint +from .utils import ensure_call_endpoint from .global_metadata import GlobalMetadataDesc as DocField from .web import DocWebModule import copy @@ -37,6 +37,8 @@ def __instancecheck__(self, __instance): class Document(ModuleBase, BuiltinGroups, metaclass=_MetaDocument): class _Manager(ModuleBase): + DEFAULT_GROUP_NAME = '__default__' + def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, Dict[str, Callable]]] = None, manager: Union[bool, str] = False, server: Union[bool, int] = False, name: Optional[str] = None, launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None, @@ -44,7 +46,8 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, doc_files: Optional[List[str]] = None, processor: Optional[DocumentProcessor] = None, display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, - doc_server_port: Optional[int] = None): + doc_server_port: Optional[int] = None, + enable_path_monitoring: Optional[bool] = None): super().__init__() self._origin_path, self._doc_files, self._cloud = dataset_path, doc_files, cloud @@ -60,13 +63,20 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, self._embed = self._get_embeds(embed) self._processor = processor self._schema_extractor = self._register_submodules(schema_extractor) - name = name or DocListManager.DEFAULT_GROUP_NAME + self._store_conf = store_conf + self._display_name = display_name + self._description = description + name = name or self.DEFAULT_GROUP_NAME if not display_name: display_name = name - - self._dlm = None if (self._cloud or self._doc_files is not None) else DocListManager( - dataset_path, name, enable_path_monitoring=False if manager else True) + if enable_path_monitoring is None: + enable_path_monitoring = False if manager else True + self._enable_path_monitoring = enable_path_monitoring self._kbs = CallableDict({name: DocImpl( - embed=self._embed, dlm=self._dlm, doc_files=doc_files, global_metadata_desc=doc_fields, + embed=self._embed, + dataset_path=dataset_path, + enable_path_monitoring=enable_path_monitoring, + doc_files=doc_files, + 
global_metadata_desc=doc_fields, store=store_conf, processor=processor, algo_name=name, display_name=display_name, description=description, schema_extractor=schema_extractor)}) @@ -107,10 +117,20 @@ def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): embed = self._get_embeds(embed) if embed else self._embed schema_extractor = self._register_submodules(schema_extractor) or self._schema_extractor - impl = DocImpl(dlm=self._dlm, embed=embed, kb_group_name=name, global_metadata_desc=doc_fields, - store=store_conf, schema_extractor=schema_extractor) + impl = DocImpl( + dataset_path=self._dataset_path, + embed=embed, + kb_group_name=name, + enable_path_monitoring=self._enable_path_monitoring, + global_metadata_desc=doc_fields, + store=store_conf or self._store_conf, + processor=self._processor, + algo_name=name, + display_name=name, + description=self._description, + schema_extractor=schema_extractor, + ) (self._kbs._impl._m if isinstance(self._kbs, ServerModule) else self._kbs)[name] = impl - self._dlm.add_kb_group(name=name) def get_doc_by_kb_group(self, name): return self._kbs._impl._m[name] if isinstance(self._kbs, ServerModule) else self._kbs[name] @@ -142,7 +162,7 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal store_conf: Optional[Dict] = None, display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, - doc_server_port: Optional[int] = None): + doc_server_port: Optional[int] = None, enable_path_monitoring: Optional[bool] = None): super().__init__() if create_ui: lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') @@ -156,7 +176,7 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal assert store_conf is None or store_conf['type'] == 'map', ( 'Only map store is supported for Document with temp-files') - name = name or DocListManager.DEFAULT_GROUP_NAME + name = name or Document._Manager.DEFAULT_GROUP_NAME if isinstance(manager, Document._Manager): assert not server, 'Server infomation is already set to by manager' @@ -184,7 +204,8 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal doc_fields, cloud=cloud, doc_files=doc_files, processor=processor, display_name=display_name, description=description, schema_extractor=schema_extractor, - doc_server_port=doc_server_port) + doc_server_port=doc_server_port, + enable_path_monitoring=enable_path_monitoring) self._curr_group = name self._doc_to_db_processor: DocToDbProcessor = None self._graph_document: weakref.ref = None @@ -424,7 +445,7 @@ def __init__(self, url: str, name: str = None): super().__init__() self._missing_keys = set(dir(Document)) - set(dir(UrlDocument)) self._manager = lazyllm.UrlModule(url=ensure_call_endpoint(url)) - self._curr_group = name or DocListManager.DEFAULT_GROUP_NAME + self._curr_group = name or Document._Manager.DEFAULT_GROUP_NAME def _forward(self, func_name: str, *args, **kwargs): args = (self._curr_group, func_name, *args) diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index f06674025..0a2716aee 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -24,7 +24,7 @@ from sqlalchemy.orm import DeclarativeBase, sessionmaker import lazyllm -from lazyllm import config +from lazyllm import config, deprecated from lazyllm.common import override from 
lazyllm.common.queue import sqlite3_check_threadsafety from lazyllm.thirdparty import tarfile @@ -102,6 +102,7 @@ class DocPathParsingResult(BaseModel): msg: str is_new: bool = False +@deprecated('Document(dataset_path=..., enable_path_monitoring=...)') class DocListManager(ABC): DEFAULT_GROUP_NAME = '__default__' __pool__ = dict() diff --git a/tests/basic_tests/RAG/test_doc_manager.py b/tests/basic_tests/RAG/test_doc_manager.py deleted file mode 100644 index ac44e0b03..000000000 --- a/tests/basic_tests/RAG/test_doc_manager.py +++ /dev/null @@ -1,294 +0,0 @@ -import pytest -import lazyllm -from lazyllm.tools.rag.utils import DocListManager -from lazyllm.tools.rag.doc_manager import DocManager -import shutil -import hashlib -import sqlite3 -import unittest -import requests -import io -import json -import time - - -@pytest.fixture(autouse=True) -def setup_tmpdir(request, tmpdir): - request.cls.tmpdir = tmpdir - - -def get_fid(path): - if isinstance(path, (tuple, list)): - return type(path)(get_fid(p) for p in path) - else: - return hashlib.sha256(f'{path}'.encode()).hexdigest() - - -@pytest.mark.usefixtures("setup_tmpdir") -class TestDocListManager(unittest.TestCase): - - def setUp(self): - self.test_dir = test_dir = self.tmpdir.mkdir("test_documents") - - test_file_1, test_file_2 = test_dir.join("test1.txt"), test_dir.join("test2.txt") - test_file_1.write("This is a test file 1.") - test_file_2.write("This is a test file 2.") - self.test_file_1, self.test_file_2 = str(test_file_1), str(test_file_2) - - self.manager = DocListManager(str(test_dir), "TestManager") - - def tearDown(self): - shutil.rmtree(str(self.test_dir)) - self.manager.release() - - def test_init_tables(self): - self.manager.init_tables() - assert self.manager.table_inited() is True - - def test_add_files(self): - self.manager.init_tables() - - self.manager.add_files([self.test_file_1, self.test_file_2]) - files_list = self.manager.list_files(details=True) - assert len(files_list) == 2 - assert any(self.test_file_1.endswith(row[1]) for row in files_list) - assert any(self.test_file_2.endswith(row[1]) for row in files_list) - - def test_list_kb_group_files(self): - self.manager.init_tables() - # wait for files to be added - time.sleep(15) - files_list = self.manager.list_kb_group_files(DocListManager.DEFAULT_GROUP_NAME, details=True) - assert len(files_list) == 2 - files_list = self.manager.list_kb_group_files('group1', details=True) - assert len(files_list) == 0 - - self.manager.add_files_to_kb_group(get_fid([self.test_file_1, self.test_file_2]), - DocListManager.DEFAULT_GROUP_NAME) - files_list = self.manager.list_kb_group_files(DocListManager.DEFAULT_GROUP_NAME, details=True) - assert len(files_list) == 2 - - self.manager.add_files_to_kb_group(get_fid([self.test_file_1, self.test_file_2]), 'group1') - files_list = self.manager.list_kb_group_files('group1', details=True) - assert len(files_list) == 2 - - def test_list_kb_groups(self): - self.manager.init_tables() - assert len(self.manager.list_all_kb_group()) == 1 - - self.manager.add_kb_group('group1') - self.manager.add_kb_group('group2') - r = self.manager.list_all_kb_group() - assert len(r) == 3 and self.manager.DEFAULT_GROUP_NAME in r and 'group2' in r - - def test_delete_files(self): - self.manager.init_tables() - - self.manager.add_files([self.test_file_1, self.test_file_2]) - self.manager.delete_files([hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest()]) - files_list = self.manager.list_files(details=True) - assert len(files_list) == 2 - files_list = 
self.manager.list_files(details=True, exclude_status=DocListManager.Status.deleting) - assert len(files_list) == 1 - assert not any(self.test_file_1.endswith(row[1]) for row in files_list) - - def test_add_deleting_file(self): - self.manager.init_tables() - - self.manager.add_files([self.test_file_1, self.test_file_2]) - self.manager.delete_files([hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest()]) - files_list = self.manager.list_files(details=True) - assert len(files_list) == 2 - files_list = self.manager.list_files(details=True, status=DocListManager.Status.deleting) - assert len(files_list) == 1 - documents = self.manager.add_files([self.test_file_1]) - assert documents == [] - - def test_update_file_message(self): - self.manager.init_tables() - - self.manager.add_files([self.test_file_1]) - file_id = hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest() - self.manager.update_file_message(file_id, meta="New metadata", status="processed") - - conn = sqlite3.connect(self.manager._db_path) - cursor = conn.execute("SELECT meta, status FROM documents WHERE doc_id = ?", (file_id,)) - row = cursor.fetchone() - conn.close() - - assert row[0] == "New metadata" - assert row[1] == "processed" - - def test_get_and_update_file_status(self): - self.manager.init_tables() - - file_id = hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest() - status = self.manager.get_file_status(file_id) - assert status[0] == DocListManager.Status.success - - self.manager.add_files([self.test_file_1], status=DocListManager.Status.waiting) - status = self.manager.get_file_status(file_id) - assert status[0] == DocListManager.Status.success - - self.manager.update_file_status([file_id], DocListManager.Status.waiting) - status = self.manager.get_file_status(file_id) - assert status[0] == DocListManager.Status.waiting - - def test_add_files_to_kb_group(self): - self.manager.init_tables() - files_list = self.manager.list_kb_group_files("group1", details=True) - assert len(files_list) == 0 - - self.manager.add_files([self.test_file_1, self.test_file_2]) - files_list = self.manager.list_kb_group_files("group1", details=True) - assert len(files_list) == 0 - - self.manager.add_files_to_kb_group(get_fid([self.test_file_1, self.test_file_2]), group="group1") - files_list = self.manager.list_kb_group_files("group1", details=True) - assert len(files_list) == 2 - - def test_delete_files_from_kb_group(self): - self.manager.init_tables() - - self.manager.add_files([self.test_file_1, self.test_file_2]) - self.manager.add_files_to_kb_group(get_fid([self.test_file_1, self.test_file_2]), group="group1") - - self.manager.delete_files_from_kb_group([hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest()], "group1") - files_list = self.manager.list_kb_group_files("group1", details=True) - # delete will literally erase the record - assert len(files_list) == 1 - - -@pytest.fixture(scope="class", autouse=True) -def setup_tmpdir_class(request, tmpdir_factory): - request.cls.tmpdir = tmpdir_factory.mktemp("class_tmpdir") - - -@pytest.mark.usefixtures("setup_tmpdir_class") -class TestDocListServer(object): - - @classmethod - def setup_class(cls): - cls.test_dir = test_dir = cls.tmpdir.mkdir("test_server") - - test_file_1, test_file_2 = test_dir.join("test1.txt"), test_dir.join("test2.txt") - test_file_1.write("This is a test file 1.") - test_file_2.write("This is a test file 2.") - cls.test_file_1, cls.test_file_2 = str(test_file_1), str(test_file_2) - - cls.manager = DocListManager(str(test_dir), "TestManager", False) - 
cls.manager.init_tables() - cls.manager.add_kb_group('group1') - cls.manager.add_kb_group('extra_group') - cls.server = lazyllm.ServerModule(DocManager(cls.manager)) - cls.server.start() - cls._test_inited = True - - test_file_extra = test_dir.join("test_extra.txt") - test_file_extra.write("This is a test file extra.") - cls.test_file_extra = str(test_file_extra) - cls.manager.add_files([cls.test_file_1, cls.test_file_2], status=DocListManager.Status.success) - time.sleep(15) - - def get_url(self, url, **kw): - url = (self.server._url.rsplit("/", 1)[0] + '/' + url).rstrip('/') - if kw: url += ('?' + '&'.join([f'{k}={v}' for k, v in kw.items()])) - return url - - def teardown_class(cls): - cls.server.stop() - shutil.rmtree(str(cls.test_dir)) - cls.manager.release() - - @pytest.mark.order(0) - def test_redirect_to_docs(self): - assert requests.get(self.get_url('')).status_code == 200 - assert requests.get(self.get_url('docs')).status_code == 200 - - @pytest.mark.order(1) - def test_list_kb_groups(self): - response = requests.get(self.get_url('list_kb_groups')) - assert response.status_code == 200 - assert response.json().get('data') == [DocListManager.DEFAULT_GROUP_NAME, 'group1', 'extra_group'] - - @pytest.mark.order(2) - def test_list_files(self): - response = requests.get(self.get_url('list_files')) - assert len(response.json().get('data')) == 2 - response = requests.get(self.get_url('list_files', limit=1)) - assert len(response.json().get('data')) == 1 - response = requests.get(self.get_url('list_files_in_group', group_name=DocListManager.DEFAULT_GROUP_NAME)) - assert len(response.json().get('data')) == 2 - response = requests.get(self.get_url('list_files_in_group', group_name='group1')) - assert len(response.json().get('data')) == 0 - - @pytest.mark.order(3) - def test_upload_files_and_upload_files_to_kb(self): - files = [('files', ('test1.txt', io.BytesIO(b"file1 content"), 'text/plain')), - ('files', ('test2.txt', io.BytesIO(b"file2 content"), 'text/plain'))] - - data = dict(override='true', metadatas=json.dumps([{"key": "value"}, {"key": "value2"}]), user_path='path') - response = requests.post(self.get_url('upload_files', **data), files=files) - assert response.status_code == 200 and response.json().get('code') == 200, response.json() - assert len(response.json().get('data')[0]) == 2 - - response = requests.get(self.get_url('list_files', details=False)) - ids = response.json().get('data') - assert response.status_code == 200 and len(ids) == 4 - - # add_files_to_group - files = [('files', ('test3.txt', io.BytesIO(b"file3 content"), 'text/plain'))] - data = dict(override='false', metadatas=json.dumps([{"key": "value"}]), group_name='group1') - response = requests.post(self.get_url('add_files_to_group', **data), files=files) - assert response.status_code == 200 - - response = requests.get(self.get_url('list_files', details=True)) - assert response.status_code == 200 and len(response.json().get('data')) == 5 - response = requests.get(self.get_url('list_files_in_group', group_name='group1')) - assert response.status_code == 200 and len(response.json().get('data')) == 1 - - @pytest.mark.order(4) - def test_add_files_to_group_and_delete_files_from_group(self): - response = requests.get(self.get_url('list_files', details=False)) - ids = response.json().get('data') - assert response.status_code == 200 and len(ids) == 5 - requests.post(self.get_url('add_files_to_group_by_id'), json=dict(file_ids=ids[:2], group_name='group1')) - response = requests.get(self.get_url('list_files_in_group', 
group_name='group1')) - assert response.status_code == 200 and len(response.json().get('data')) == 3 - - requests.post(self.get_url('delete_files_from_group'), json=dict(file_ids=ids[:1], group_name='group1')) - response = requests.get(self.get_url('list_files_in_group', group_name='group1')) - assert response.status_code == 200 and len(response.json().get('data')) == 3 - response = requests.get(self.get_url('list_files_in_group', group_name='group1', alive=True)) - assert response.status_code == 200 and len(response.json().get('data')) == 2 - - @pytest.mark.order(5) - def test_delete_files(self): - response = requests.get(self.get_url('list_files', details=False)) - ids = response.json().get('data') - assert response.status_code == 200 and len(ids) == 5 - - response = requests.post(self.get_url('delete_files'), json=dict(file_ids=ids[-1:])) - lazyllm.LOG.warning(response.json()) - assert response.status_code == 200 and response.json().get('code') == 200 - - response = requests.get(self.get_url('list_files')) - assert response.status_code == 200 and len(response.json().get('data')) == 5 - response = requests.get(self.get_url('list_files', alive=True)) - assert response.status_code == 200 and len(response.json().get('data')) == 4 - - response = requests.get(self.get_url('list_files_in_group', group_name='group1')) - assert response.status_code == 200 and len(response.json().get('data')) == 3 - response = requests.get(self.get_url('list_files_in_group', group_name='group1', alive=True)) - assert response.status_code == 200 and len(response.json().get('data')) == 1 - - @pytest.mark.order(6) - def test_add_files(self): - json_data = { - 'files': [self.test_file_extra, "fake path"], - 'group_name': "extra_group", - 'metadatas': json.dumps([{"key": "value"}, {"key": "value"}]) - } - response = requests.post(self.get_url('add_files'), json=json_data) - assert response.status_code == 200 - assert len(response.json().get('data')) == 2 and response.json().get('data')[1] is None diff --git a/tests/basic_tests/RAG/test_doc_service_doc_manager.py b/tests/basic_tests/RAG/test_doc_service_doc_manager.py new file mode 100644 index 000000000..7e92df718 --- /dev/null +++ b/tests/basic_tests/RAG/test_doc_service_doc_manager.py @@ -0,0 +1,559 @@ +import os +import tempfile +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from uuid import uuid4 + +import pytest + +from lazyllm.tools.rag.doc_service.base import ( + AddFileItem, + CallbackEventType, + DeleteRequest, + DocServiceError, + DocStatus, + ReparseRequest, + SourceType, + TaskCallbackRequest, + TransferItem, + TransferRequest, + UploadRequest, +) +from lazyllm.tools.rag.doc_service.doc_manager import DocManager, _ParserClient +from lazyllm.tools.rag.parsing_service.base import TaskType +from lazyllm.tools.rag.utils import BaseResponse + + +class _ManagerHarness: + def __init__(self): + self._tmp_dir = tempfile.TemporaryDirectory(prefix='lazyllm_doc_service_manager_') + self.tmp_dir = self._tmp_dir.name + self.seed_path = self.make_file('seed.txt', 'seed content') + self.db_config = { + 'db_type': 'sqlite', + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'db_name': os.path.join(self.tmp_dir, 'doc_service_local.db'), + } + self.manager = DocManager(db_config=self.db_config, parser_url='http://parser.test') + self.pending_task_status = {} + self.cancel_calls = [] + self.delete_calls = [] + self.chunk_calls = [] + self.add_doc_calls = [] + self.chunk_response = BaseResponse(code=200, 
msg='success', data={'items': [], 'total': 0}) + self._patch_parser_client() + + def close(self): + self._tmp_dir.cleanup() + + def make_file(self, name: str, content: str): + file_path = os.path.join(self.tmp_dir, name) + with open(file_path, 'w', encoding='utf-8') as file: + file.write(content) + return file_path + + def finish_task(self, task_id: str, status: DocStatus = DocStatus.SUCCESS, callback_id: str = None): + return self.manager.on_task_callback(TaskCallbackRequest( + callback_id=callback_id or str(uuid4()), + task_id=task_id, + event_type=CallbackEventType.FINISH, + status=status, + )) + + def start_task(self, task_id: str, callback_id: str = None): + return self.manager.on_task_callback(TaskCallbackRequest( + callback_id=callback_id or str(uuid4()), + task_id=task_id, + event_type=CallbackEventType.START, + status=DocStatus.WORKING, + )) + + def _queue_task(self, task_id: str, final_status: DocStatus): + self.pending_task_status[task_id] = final_status + + def _patch_parser_client(self): + def add_doc(task_id, algo_id, kb_id, doc_id, file_path, metadata=None, reparse_group=None, + callback_url=None, transfer_params=None): + self.add_doc_calls.append({ + 'task_id': task_id, + 'algo_id': algo_id, + 'kb_id': kb_id, + 'doc_id': doc_id, + 'file_path': file_path, + 'metadata': metadata, + 'reparse_group': reparse_group, + 'callback_url': callback_url, + 'transfer_params': transfer_params, + }) + self._queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + + def update_meta(task_id, algo_id, kb_id, doc_id, metadata=None, file_path=None, callback_url=None): + del doc_id, metadata, file_path, callback_url + self._queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + + def delete_doc(task_id, algo_id, kb_id, doc_id, callback_url=None): + del callback_url + self.delete_calls.append({'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id, 'doc_id': doc_id}) + self._queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + + def cancel_task(task_id): + self.cancel_calls.append(task_id) + return BaseResponse(code=200, msg='success', data={'task_id': task_id, 'cancel_status': True}) + + def list_doc_chunks(algo_id, kb_id, doc_id, group, offset, page_size): + self.chunk_calls.append({ + 'algo_id': algo_id, + 'kb_id': kb_id, + 'doc_id': doc_id, + 'group': group, + 'offset': offset, + 'page_size': page_size, + }) + return self.chunk_response + + self.manager._parser_client.add_doc = add_doc + self.manager._parser_client.update_meta = update_meta + self.manager._parser_client.delete_doc = delete_doc + self.manager._parser_client.cancel_task = cancel_task + self.manager._parser_client.list_doc_chunks = list_doc_chunks + self.manager._parser_client.list_algorithms = lambda: BaseResponse( + code=200, + msg='success', + data=[{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}], + ) + self.manager._parser_client.get_algorithm_groups = lambda algo_id: BaseResponse( + code=200, + msg='success', + data=[{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}] if algo_id == '__default__' else None, + ) + + +@pytest.fixture +def manager_harness(): + harness = _ManagerHarness() + try: + yield harness + finally: + harness.close() + + +def test_manager_run_idempotent_atomic(manager_harness): + 
started = [] + + def handler(): + started.append(time.time()) + time.sleep(0.2) + return {'task_id': str(uuid4())} + + with ThreadPoolExecutor(max_workers=2) as pool: + future = pool.submit(manager_harness.manager.run_idempotent, '/local/atomic', 'same-key', {'k': 1}, handler) + time.sleep(0.05) + with pytest.raises(DocServiceError) as exc: + manager_harness.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) + result = future.result(timeout=2) + + assert exc.value.biz_code == 'E_IDEMPOTENCY_IN_PROGRESS' + replay = manager_harness.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) + assert len(started) == 1 + assert replay == result + + +def test_manager_upload_callback_and_doc_detail(manager_harness): + manager_harness.manager.create_kb('kb_upload', algo_id='__default__') + file_path = manager_harness.make_file('upload.txt', 'upload content') + + items = manager_harness.manager.upload(UploadRequest( + kb_id='kb_upload', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='upload-doc')], + )) + + task_id = items[0]['task_id'] + assert items[0]['accepted'] is True + assert items[0]['parse_status'] == DocStatus.WAITING.value + + manager_harness.finish_task(task_id) + + task = manager_harness.manager.get_task(task_id) + detail = manager_harness.manager.get_doc_detail('upload-doc') + + assert task.code == 200 + assert task.data['status'] == DocStatus.SUCCESS.value + assert detail['doc']['upload_status'] == DocStatus.SUCCESS.value + assert detail['snapshot']['status'] == DocStatus.SUCCESS.value + assert detail['latest_task']['task_id'] == task_id + + +def test_manager_cancel_waiting_add_updates_all_states(manager_harness): + manager_harness.manager.create_kb('kb_cancel', algo_id='__default__') + file_path = manager_harness.make_file('cancel.txt', 'cancel content') + + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_cancel', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='cancel-doc')], + )) + task_id = upload[0]['task_id'] + + resp = manager_harness.manager.cancel_task(task_id) + snapshot = manager_harness.manager._get_parse_snapshot('cancel-doc', 'kb_cancel', '__default__') + doc = manager_harness.manager._get_doc('cancel-doc') + task = manager_harness.manager.get_task(task_id) + + assert resp.code == 200 + assert resp.data['status'] == DocStatus.CANCELED.value + assert task.data['status'] == DocStatus.CANCELED.value + assert snapshot['status'] == DocStatus.CANCELED.value + assert doc['upload_status'] == DocStatus.CANCELED.value + + +def test_manager_cancel_working_task_rejected(manager_harness): + manager_harness.manager.create_kb('kb_working', algo_id='__default__') + file_path = manager_harness.make_file('working.txt', 'working content') + + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_working', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='working-doc')], + )) + task_id = upload[0]['task_id'] + manager_harness.start_task(task_id) + + resp = manager_harness.manager.cancel_task(task_id) + + assert resp.code == 409 + assert resp.data['status'] == DocStatus.WORKING.value + + +def test_manager_delete_waiting_add_uses_cancel_path(manager_harness): + manager_harness.manager.create_kb('kb_delete_waiting', algo_id='__default__') + file_path = manager_harness.make_file('delete_waiting.txt', 'delete waiting content') + + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_delete_waiting', + algo_id='__default__', + 
items=[AddFileItem(file_path=file_path, doc_id='delete-waiting-doc')], + )) + original_task_id = upload[0]['task_id'] + + items = manager_harness.manager.delete(DeleteRequest( + kb_id='kb_delete_waiting', + algo_id='__default__', + doc_ids=['delete-waiting-doc'], + )) + snapshot = manager_harness.manager._get_parse_snapshot('delete-waiting-doc', 'kb_delete_waiting', '__default__') + doc = manager_harness.manager._get_doc('delete-waiting-doc') + + assert items[0]['task_id'] == original_task_id + assert items[0]['status'] == DocStatus.CANCELED.value + assert manager_harness.cancel_calls == [original_task_id] + assert manager_harness.delete_calls == [] + assert snapshot['status'] == DocStatus.CANCELED.value + assert doc['upload_status'] == DocStatus.CANCELED.value + + +def test_manager_stale_callback_ignored_after_reparse(manager_harness): + manager_harness.manager.create_kb('kb_stale', algo_id='__default__') + file_path = manager_harness.make_file('stale.txt', 'stale content') + + uploaded = manager_harness.manager.upload(UploadRequest( + kb_id='kb_stale', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='stale-doc')], + )) + manager_harness.finish_task(uploaded[0]['task_id']) + + first_task_id = manager_harness.manager.reparse(ReparseRequest( + kb_id='kb_stale', + algo_id='__default__', + doc_ids=['stale-doc'], + ))[0] + second_task_id = manager_harness.manager.reparse(ReparseRequest( + kb_id='kb_stale', + algo_id='__default__', + doc_ids=['stale-doc'], + ))[0] + + stale_resp = manager_harness.manager.on_task_callback(TaskCallbackRequest( + callback_id='stale-callback', + task_id=first_task_id, + event_type=CallbackEventType.FINISH, + status=DocStatus.SUCCESS, + )) + + assert first_task_id != second_task_id + assert stale_resp['ignored_reason'] == 'stale_task_callback' + + +def test_manager_transfer_uses_target_doc_id_for_target_records(manager_harness): + manager_harness.manager.create_kb('kb_transfer_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_transfer_target', algo_id='__default__') + file_path = manager_harness.make_file('transfer.txt', 'transfer content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_transfer_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='source-doc')], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='source-doc', + target_doc_id='target-doc-copy', + source_kb_id='kb_transfer_source', + source_algo_id='__default__', + target_kb_id='kb_transfer_target', + target_algo_id='__default__', + mode='copy', + )])) + + task_id = items[0]['task_id'] + manager_harness.finish_task(task_id) + target_doc = manager_harness.manager._get_doc('target-doc-copy') + + assert items[0]['doc_id'] == 'source-doc' + assert items[0]['target_doc_id'] == 'target-doc-copy' + assert manager_harness.add_doc_calls[-1]['doc_id'] == 'source-doc' + assert manager_harness.add_doc_calls[-1]['transfer_params']['target_doc_id'] == 'target-doc-copy' + assert manager_harness.add_doc_calls[-1]['file_path'] == file_path + assert manager_harness.manager._has_kb_document('kb_transfer_target', 'target-doc-copy') is True + assert manager_harness.manager._has_kb_document('kb_transfer_target', 'source-doc') is False + assert target_doc['upload_status'] == DocStatus.SUCCESS.value + assert target_doc['filename'] == 'transfer.txt' + assert target_doc['meta'] == '{}' + assert target_doc['path'] == file_path + 
assert manager_harness.manager._has_kb_document('kb_transfer_source', 'source-doc') is True + + +def test_manager_transfer_move_cleans_source_doc_with_target_doc_id(manager_harness): + manager_harness.manager.create_kb('kb_move_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_move_target', algo_id='__default__') + file_path = manager_harness.make_file('move.txt', 'move content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_move_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='source-doc-move')], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='source-doc-move', + target_doc_id='target-doc-move', + source_kb_id='kb_move_source', + source_algo_id='__default__', + target_kb_id='kb_move_target', + target_algo_id='__default__', + mode='move', + )])) + + task_id = items[0]['task_id'] + manager_harness.finish_task(task_id) + + assert manager_harness.manager._has_kb_document('kb_move_source', 'source-doc-move') is False + assert manager_harness.manager._has_kb_document('kb_move_target', 'target-doc-move') is True + assert manager_harness.manager._get_doc('source-doc-move')['upload_status'] == DocStatus.DELETED.value + assert manager_harness.manager._get_doc('target-doc-move')['upload_status'] == DocStatus.SUCCESS.value + assert manager_harness.manager._get_parse_snapshot('source-doc-move', 'kb_move_source', '__default__') is None + + +def test_manager_transfer_target_fields_override_source_defaults(manager_harness): + manager_harness.manager.create_kb('kb_override_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_override_target', algo_id='__default__') + file_path = manager_harness.make_file('override.txt', 'override content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_override_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='source-doc-override', metadata={'a': 1, 'b': 'keep'})], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='source-doc-override', + target_doc_id='target-doc-override', + source_kb_id='kb_override_source', + source_algo_id='__default__', + target_kb_id='kb_override_target', + target_algo_id='__default__', + target_metadata={'a': 2, 'c': 'new'}, + target_filename='renamed.txt', + mode='copy', + )])) + + manager_harness.finish_task(items[0]['task_id']) + target_doc = manager_harness.manager._get_doc('target-doc-override') + + assert target_doc['filename'] == 'renamed.txt' + assert target_doc['meta'] == '{"a": 2, "b": "keep", "c": "new"}' + assert os.path.basename(target_doc['path']) == 'renamed.txt' + assert manager_harness.add_doc_calls[-1]['metadata'] == {'a': 2, 'b': 'keep', 'c': 'new'} + assert manager_harness.add_doc_calls[-1]['file_path'] == file_path + + +def test_manager_transfer_target_file_path_overrides_target_file_info(manager_harness): + manager_harness.manager.create_kb('kb_path_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_path_target', algo_id='__default__') + file_path = manager_harness.make_file('path-source.txt', 'path source content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_path_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='source-doc-path')], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = 
manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='source-doc-path', + target_doc_id='target-doc-path', + source_kb_id='kb_path_source', + source_algo_id='__default__', + target_kb_id='kb_path_target', + target_algo_id='__default__', + target_file_path='/virtual/target/renamed-from-path.txt', + mode='copy', + )])) + + manager_harness.finish_task(items[0]['task_id']) + target_doc = manager_harness.manager._get_doc('target-doc-path') + + assert items[0]['target_file_path'] == '/virtual/target/renamed-from-path.txt' + assert target_doc['filename'] == 'renamed-from-path.txt' + assert target_doc['path'] == '/virtual/target/renamed-from-path.txt' + assert manager_harness.add_doc_calls[-1]['file_path'] == file_path + + +def test_manager_list_chunks_forwards_true_pagination(manager_harness): + manager_harness.manager.create_kb('kb_chunks', algo_id='__default__') + file_path = manager_harness.make_file('chunks.txt', 'chunks content') + uploaded = manager_harness.manager.upload(UploadRequest( + kb_id='kb_chunks', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='chunks-doc')], + )) + manager_harness.finish_task(uploaded[0]['task_id']) + manager_harness.chunk_response = BaseResponse(code=200, msg='success', data={ + 'items': [{'uid': 'chunk-2'}, {'uid': 'chunk-3'}], + 'total': 7, + }) + + data = manager_harness.manager.list_chunks( + kb_id='kb_chunks', + doc_id='chunks-doc', + group='line', + algo_id='__default__', + page=2, + page_size=2, + ) + + assert manager_harness.chunk_calls == [{ + 'algo_id': '__default__', + 'kb_id': 'kb_chunks', + 'doc_id': 'chunks-doc', + 'group': 'line', + 'offset': 2, + 'page_size': 2, + }] + assert data['items'] == [{'uid': 'chunk-2'}, {'uid': 'chunk-3'}] + assert data['total'] == 7 + assert data['page'] == 2 + assert data['offset'] == 2 + + +def test_manager_callback_payload_fallback_and_delete_transition(manager_harness): + manager_harness.manager.create_kb('kb_callback', algo_id='__default__') + file_path = manager_harness.make_file('callback.txt', 'callback content') + manager_harness.manager._upsert_doc( + doc_id='callback-doc', + filename='callback.txt', + path=file_path, + metadata={'case': 'callback'}, + source_type=SourceType.EXTERNAL, + ) + manager_harness.manager._ensure_kb_document('kb_callback', 'callback-doc') + queued_at = manager_harness.manager._upsert_parse_snapshot( + doc_id='callback-doc', + kb_id='kb_callback', + algo_id='__default__', + status=DocStatus.DELETING, + task_type=TaskType.DOC_DELETE, + current_task_id='delete-task', + queued_at=datetime.now(), + )['queued_at'] + + start_resp = manager_harness.manager.on_task_callback(TaskCallbackRequest( + callback_id='delete-start', + task_id='delete-task', + event_type=CallbackEventType.START, + status=DocStatus.WORKING, + payload={ + 'task_type': TaskType.DOC_DELETE.value, + 'doc_id': 'callback-doc', + 'kb_id': 'kb_callback', + 'algo_id': '__default__', + }, + )) + start_snapshot = manager_harness.manager._get_parse_snapshot('callback-doc', 'kb_callback', '__default__') + finish_resp = manager_harness.manager.on_task_callback(TaskCallbackRequest( + callback_id='delete-finish', + task_id='delete-task', + event_type=CallbackEventType.FINISH, + status=DocStatus.SUCCESS, + payload={ + 'task_type': TaskType.DOC_DELETE.value, + 'doc_id': 'callback-doc', + 'kb_id': 'kb_callback', + 'algo_id': '__default__', + }, + )) + snapshot = manager_harness.manager._get_parse_snapshot('callback-doc', 'kb_callback', '__default__') + + assert start_resp['ack'] is True 
+ assert finish_resp['ack'] is True + assert start_snapshot['queued_at'] == queued_at + assert snapshot['status'] == DocStatus.DELETED.value + assert manager_harness.manager._has_kb_document('kb_callback', 'callback-doc') is False + assert manager_harness.manager._get_doc('callback-doc')['upload_status'] == DocStatus.DELETED.value + + +def test_parser_client_algo_endpoint_fallback(): + client = _ParserClient(parser_url='http://parser.test') + calls = [] + + def fake_get(path, params=None): + del params + calls.append(path) + if path == '/v1/algo/list': + raise RuntimeError('parser http error: 404 missing route') + if path == '/algo/list': + return { + 'code': 200, + 'msg': 'success', + 'data': [{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}], + } + if path == '/v1/algo/__default__/groups': + raise RuntimeError('parser http error: 404 missing route') + if path == '/algo/__default__/group/info': + return { + 'code': 200, + 'msg': 'success', + 'data': [{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}], + } + raise AssertionError(path) + + client._get = fake_get + + algo_resp = client.list_algorithms() + group_resp = client.get_algorithm_groups('__default__') + + assert algo_resp.code == 200 + assert group_resp.code == 200 + assert calls == [ + '/v1/algo/list', + '/algo/list', + '/v1/algo/__default__/groups', + '/algo/__default__/group/info', + ] diff --git a/tests/basic_tests/RAG/test_doc_service_doc_server.py b/tests/basic_tests/RAG/test_doc_service_doc_server.py new file mode 100644 index 000000000..5ecff01a4 --- /dev/null +++ b/tests/basic_tests/RAG/test_doc_service_doc_server.py @@ -0,0 +1,248 @@ +import asyncio +import json +import os +import tempfile + +import pytest + +from lazyllm.thirdparty import fastapi + +from lazyllm.tools.rag.doc_service.doc_server import DocServer +from lazyllm.tools.rag.doc_service.base import ( + AddFileItem, + CallbackEventType, + DocServiceError, + DocStatus, + KbUpdateRequest, + SourceType, + TaskCallbackRequest, + UploadRequest, +) +from lazyllm.tools.rag.utils import BaseResponse + + +class _JsonRequest: + def __init__(self, payload): + self._payload = payload + + async def json(self): + return self._payload + + +class _FormData: + def __init__(self, files): + self._files = files + + def getlist(self, name): + assert name == 'files' + return self._files + + +class _FormRequest: + def __init__(self, files): + self._form = _FormData(files) + + async def form(self): + return self._form + + +class _UploadFile: + def __init__(self, filename: str, content: bytes): + self.filename = filename + self._content = content + + async def read(self): + return self._content + + +class _FakeManager: + def __init__(self): + self.run_calls = [] + self.upload_request = None + self.chunk_kwargs = None + self.cancel_response = BaseResponse( + code=200, msg='success', data={'task_id': 'task-1', 'cancel_status': True, 'status': 'CANCELED'} + ) + + def run_idempotent(self, endpoint, idempotency_key, payload, handler): + self.run_calls.append({ + 'endpoint': endpoint, + 'idempotency_key': idempotency_key, + 'payload': payload, + }) + return handler() + + def upload(self, request: UploadRequest): + self.upload_request = request + return [ + {'doc_id': item.doc_id or 'generated-doc', 'task_id': f'task-{idx}'} + for idx, item in enumerate(request.items) + ] + + def cancel_task(self, task_id: str): + if isinstance(self.cancel_response, Exception): + raise self.cancel_response + resp = self.cancel_response + if isinstance(resp, BaseResponse): + 
return resp + return BaseResponse.model_validate(resp) + + def list_chunks(self, **kwargs): + self.chunk_kwargs = kwargs + return {'items': [{'uid': 'chunk-1'}], 'total': 1, 'page': kwargs['page'], 'page_size': kwargs['page_size']} + + +def _decode_response(response): + assert isinstance(response, fastapi.responses.JSONResponse) + return json.loads(response.body.decode()) + + +@pytest.fixture +def server_impl(): + with tempfile.TemporaryDirectory(prefix='lazyllm_doc_server_unit_') as temp_dir: + impl = DocServer._Impl(storage_dir=temp_dir, parser_url='http://parser.test') + impl._manager = _FakeManager() + impl._lazy_init = lambda: None + yield impl + + +def test_build_update_kb_payload_distinguishes_omitted_and_null(): + keep_req = KbUpdateRequest(display_name='Renamed', idempotency_key='kb-update-idem') + clear_req = KbUpdateRequest(display_name='Renamed', owner_id=None, idempotency_key='kb-update-idem') + + keep_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', keep_req) + clear_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', clear_req) + + assert keep_payload != clear_payload + assert keep_payload['explicit_fields'] == ['display_name', 'idempotency_key'] + assert clear_payload['explicit_fields'] == ['display_name', 'idempotency_key', 'owner_id'] + + +def test_normalize_task_callback_supports_legacy_fields(): + callback = DocServer._Impl._normalize_task_callback({ + 'task_id': 'task-1', + 'task_type': 'DOC_ADD', + 'doc_id': 'doc-1', + 'kb_id': 'kb-1', + 'algo_id': '__default__', + 'task_status': 'SUCCESS', + }) + + assert isinstance(callback, TaskCallbackRequest) + assert callback.event_type == CallbackEventType.FINISH + assert callback.status == DocStatus.SUCCESS + assert callback.payload == { + 'task_type': 'DOC_ADD', + 'doc_id': 'doc-1', + 'kb_id': 'kb-1', + 'algo_id': '__default__', + } + + +def test_run_wraps_doc_service_error(): + response = DocServer._Impl._response(data={'ok': True}) + assert _decode_response(response)['data'] == {'ok': True} + + impl = DocServer._Impl(storage_dir='.', parser_url='http://parser.test') + wrapped = impl._run(lambda: (_ for _ in ()).throw(DocServiceError('E_INVALID_PARAM', 'bad req', {'x': 1}))) + body = _decode_response(wrapped) + + assert body['code'] == 400 + assert body['msg'] == 'bad req' + assert body['data']['biz_code'] == 'E_INVALID_PARAM' + assert body['data']['x'] == 1 + + +def test_cancel_task_http_requires_task_id(server_impl): + with pytest.raises(fastapi.HTTPException) as exc: + asyncio.run(server_impl.cancel_task(_JsonRequest({}))) + assert exc.value.status_code == 400 + assert exc.value.detail == 'task_id is required' + + +def test_cancel_task_http_maps_conflict(server_impl): + server_impl._manager.cancel_response = BaseResponse( + code=409, + msg='task cannot be canceled', + data={'task_id': 'task-1', 'cancel_status': False, 'status': 'WORKING'}, + ) + + response = asyncio.run(server_impl.cancel_task(_JsonRequest({'task_id': 'task-1', 'idempotency_key': 'idem'}))) + body = _decode_response(response) + + assert server_impl._manager.run_calls[0]['endpoint'] == '/v1/tasks/cancel' + assert server_impl._manager.run_calls[0]['idempotency_key'] == 'idem' + assert body['code'] == 409 + assert body['data']['biz_code'] == 'E_STATE_CONFLICT' + assert body['data']['status'] == 'WORKING' + + +def test_list_chunks_forwards_pagination_to_manager(server_impl): + response = server_impl.list_chunks( + kb_id='kb-1', + doc_id='doc-1', + group='block', + algo_id='algo-1', + page=3, + page_size=4, + offset=8, + ) + 
body = _decode_response(response) + + assert server_impl._manager.chunk_kwargs == { + 'kb_id': 'kb-1', + 'doc_id': 'doc-1', + 'group': 'block', + 'algo_id': 'algo-1', + 'page': 3, + 'page_size': 4, + 'offset': 8, + } + assert body['data']['items'] == [{'uid': 'chunk-1'}] + assert body['data']['total'] == 1 + + +def test_upload_request_uses_idempotency_payload(server_impl): + file_path = os.path.join(server_impl._storage_dir, 'seed.txt') + with open(file_path, 'w', encoding='utf-8') as file: + file.write('seed content') + request = UploadRequest( + kb_id='kb-upload', + algo_id='__default__', + source_type=SourceType.API, + idempotency_key='upload-idem', + items=[AddFileItem(file_path=file_path, doc_id='doc-seed')], + ) + + response = server_impl.upload_request(request) + body = _decode_response(response) + + assert server_impl._manager.run_calls[0]['endpoint'] == '/v1/docs/upload' + assert server_impl._manager.run_calls[0]['payload']['items'][0]['doc_id'] == 'doc-seed' + assert body['data']['items'][0]['doc_id'] == 'doc-seed' + + +def test_upload_http_saves_unique_files_and_only_first_doc_id(server_impl): + files = [ + _UploadFile('dup.txt', b'first content'), + _UploadFile('dup.txt', b'second content'), + ] + + response = asyncio.run(server_impl.upload( + _FormRequest(files), + kb_id='kb-upload', + algo_id='__default__', + source_type=SourceType.API, + doc_id='doc-first', + idempotency_key='upload-http-idem', + )) + body = _decode_response(response) + upload_request = server_impl._manager.upload_request + + assert body['code'] == 200 + assert len(upload_request.items) == 2 + assert upload_request.items[0].doc_id == 'doc-first' + assert upload_request.items[1].doc_id is None + assert upload_request.items[0].file_path != upload_request.items[1].file_path + assert os.path.exists(upload_request.items[0].file_path) + assert os.path.exists(upload_request.items[1].file_path) diff --git a/tests/basic_tests/RAG/test_doc_service_mock.py b/tests/basic_tests/RAG/test_doc_service_mock.py index a7e3cc593..882acbebf 100644 --- a/tests/basic_tests/RAG/test_doc_service_mock.py +++ b/tests/basic_tests/RAG/test_doc_service_mock.py @@ -5,20 +5,11 @@ import tempfile import time from concurrent.futures import ThreadPoolExecutor -from datetime import datetime -from uuid import uuid4 import pytest import requests from lazyllm.tools.rag.doc_service import DocServer -from lazyllm.tools.rag.doc_service.base import ( - AddFileItem, CallbackEventType, DeleteRequest, DocServiceError, DocStatus, KbUpdateRequest, ReparseRequest, - SourceType, TaskCallbackRequest, UploadRequest, -) -from lazyllm.tools.rag.doc_service.doc_manager import DocManager, _ParserClient -from lazyllm.tools.rag.parsing_service.base import TaskType -from lazyllm.tools.rag.utils import BaseResponse @pytest.mark.skip_on_win @@ -169,6 +160,7 @@ def test_p0_endpoints_and_core_flows(self): 'items': [ { 'doc_id': doc_add, + 'target_doc_id': 'copied-seed-doc-1', 'source_kb_id': 'kb_a', 'source_algo_id': '__default__', 'target_kb_id': 'kb_b', @@ -522,6 +514,7 @@ def test_kb_algo_binding_and_transfer_validation(self): json={ 'items': [{ 'doc_id': doc_id, + 'target_doc_id': 'invalid-transfer-doc', 'source_kb_id': 'kb_bind', 'source_algo_id': '__default__', 'target_kb_id': 'kb_bind', @@ -660,307 +653,3 @@ def test_kb_update_pagination_and_batch_query(self): assert len(batch_data['items']) == 1 assert batch_data['items'][0]['kb_id'] == 'kb_page_1' assert batch_data['missing_kb_ids'] == ['kb_missing'] - - -class TestDocServiceMockLocal: - @classmethod - def 
setup_class(cls): - cls._tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_local_') - cls._seed_path = os.path.join(cls._tmp_dir, 'seed.txt') - with open(cls._seed_path, 'w', encoding='utf-8') as f: - f.write('local seed content') - cls._db_config = { - 'db_type': 'sqlite', - 'user': None, - 'password': None, - 'host': None, - 'port': None, - 'db_name': os.path.join(cls._tmp_dir, 'doc_service_local.db'), - } - cls.manager = DocManager(db_config=cls._db_config, parser_url='http://parser.test') - cls._pending_task_status = {} - - def _queue_task(task_id: str, final_status: DocStatus): - cls._pending_task_status[task_id] = final_status - - cls.manager._parser_client.add_doc = lambda task_id, algo_id, kb_id, doc_id, file_path, metadata=None, reparse_group=None: ( - _queue_task(task_id, DocStatus.SUCCESS) or - BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) - ) - cls.manager._parser_client.update_meta = lambda task_id, algo_id, kb_id, doc_id, metadata=None, file_path=None: ( - _queue_task(task_id, DocStatus.SUCCESS) or - BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) - ) - cls.manager._parser_client.delete_doc = lambda task_id, algo_id, kb_id, doc_id: ( - _queue_task(task_id, DocStatus.SUCCESS) or - BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) - ) - cls.manager._parser_client.cancel_task = lambda task_id: BaseResponse( - code=200, msg='success', data={'task_id': task_id, 'cancel_status': True} - ) - cls.manager._parser_client.list_algorithms = lambda: BaseResponse( - code=200, msg='success', data=[{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}] - ) - cls.manager._parser_client.get_algorithm_groups = lambda algo_id: BaseResponse( - code=200, - msg='success', - data=[{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}] if algo_id == '__default__' else None, - ) - - @classmethod - def teardown_class(cls): - shutil.rmtree(cls._tmp_dir, ignore_errors=True) - - def _wait_task(self, task_id, target_statuses, timeout=8): - deadline = time.time() + timeout - last = None - while time.time() < deadline: - resp = self.manager.get_task(task_id) - assert resp.code == 200 - last = resp.data - if last['status'] in target_statuses: - return last - pending_status = self._pending_task_status.pop(task_id, None) - if pending_status is not None: - self.manager.on_task_callback(TaskCallbackRequest( - task_id=task_id, - event_type=CallbackEventType.FINISH, - status=pending_status, - )) - time.sleep(0.05) - raise AssertionError(f'task {task_id} not finished in time, last={last}') - - def _make_file(self, name: str, content: str): - file_path = os.path.join(self._tmp_dir, name) - with open(file_path, 'w', encoding='utf-8') as f: - f.write(content) - return file_path - - def test_manager_atomic_idempotency(self): - started = [] - - def handler(): - started.append(time.time()) - time.sleep(0.2) - return {'task_id': str(uuid4())} - - with ThreadPoolExecutor(max_workers=2) as pool: - future = pool.submit(self.manager.run_idempotent, '/local/atomic', 'same-key', {'k': 1}, handler) - time.sleep(0.05) - with pytest.raises(DocServiceError) as exc: - self.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) - result = future.result(timeout=2) - - assert exc.value.biz_code == 'E_IDEMPOTENCY_IN_PROGRESS' - replay = self.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) - assert len(started) 
== 1 - assert replay == result - - def test_manager_kb_algo_binding(self): - self.manager.create_kb('kb_local_bind', algo_id='__default__') - file_path = self._make_file('local_bind.txt', 'local bind content') - with pytest.raises(DocServiceError) as exc: - self.manager.upload(UploadRequest( - kb_id='kb_local_bind', - algo_id='wrong_algo', - items=[AddFileItem(file_path=file_path, doc_id='local-bind-doc')], - )) - assert exc.value.biz_code == 'E_INVALID_PARAM' - - def test_manager_stale_callback_and_state_conflict(self): - self.manager.create_kb('kb_local_stale', algo_id='__default__') - file_path = self._make_file('local_stale.txt', 'local stale content') - uploaded = self.manager.upload(UploadRequest( - kb_id='kb_local_stale', - algo_id='__default__', - items=[AddFileItem(file_path=file_path, doc_id='local-stale-doc')], - )) - self._wait_task(uploaded[0]['task_id'], {'SUCCESS'}) - first_task_id = self.manager.reparse(ReparseRequest( - kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], - ))[0] - second_task_id = self.manager.reparse(ReparseRequest( - kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], - ))[0] - stale_resp = self.manager.on_task_callback(TaskCallbackRequest( - callback_id='local-stale-callback', - task_id=first_task_id, - event_type=CallbackEventType.FINISH, - status=DocStatus.SUCCESS, - )) - assert stale_resp['ignored_reason'] == 'stale_task_callback' - self.manager.delete(DeleteRequest(kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'])) - with pytest.raises(DocServiceError) as exc: - self.manager.reparse(ReparseRequest( - kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], - )) - assert exc.value.biz_code == 'E_STATE_CONFLICT' - assert second_task_id != first_task_id - - def test_manager_missing_endpoint_surrogates(self): - self.manager.create_kb('kb_local_info', algo_id='__default__') - file_path = self._make_file('local_info.txt', 'local info content') - uploaded = self.manager.upload(UploadRequest( - kb_id='kb_local_info', - algo_id='__default__', - items=[AddFileItem(file_path=file_path, doc_id='local-info-doc')], - )) - algorithms = self.manager.list_algorithms_compat() - assert len(algorithms['items']) >= 1 - algo_info = self.manager.get_algorithm_info('__default__') - assert algo_info['algo_id'] == '__default__' - chunks = self.manager.list_chunks() - assert chunks['items'] == [] - tasks_batch = self.manager.get_tasks_batch([uploaded[0]['task_id']]) - assert len(tasks_batch['items']) == 1 - - def test_delete_kbs_empty_list_rejected(self): - with pytest.raises(DocServiceError) as exc: - self.manager.delete_kbs([]) - assert exc.value.biz_code == 'E_INVALID_PARAM' - - def test_manager_rejects_unknown_kb_algorithm(self): - with pytest.raises(DocServiceError) as exc: - self.manager.create_kb('kb_local_unknown_algo', algo_id='missing_algo') - assert exc.value.biz_code == 'E_INVALID_PARAM' - - def test_manager_update_kb_can_clear_nullable_fields(self): - self.manager.create_kb( - 'kb_local_clearable', - display_name='Clearable', - description='to be cleared', - owner_id='owner-x', - meta={'tag': 'x'}, - algo_id='__default__', - ) - updated = self.manager.update_kb( - 'kb_local_clearable', - display_name=None, - description=None, - owner_id=None, - meta=None, - explicit_fields={'display_name', 'description', 'owner_id', 'meta'}, - ) - assert updated['display_name'] is None - assert updated['description'] is None - assert updated['owner_id'] is None - assert updated['meta'] == 
{} - - def test_kb_update_idempotency_payload_distinguishes_omitted_and_null(self): - keep_req = KbUpdateRequest(display_name='Renamed', idempotency_key='kb-update-idem') - clear_req = KbUpdateRequest(display_name='Renamed', owner_id=None, idempotency_key='kb-update-idem') - - keep_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', keep_req) - clear_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', clear_req) - - assert keep_payload != clear_payload - - self.manager.run_idempotent( - '/v1/kbs/kb_local_idem:patch', - 'kb-update-idem', - keep_payload, - lambda: {'kb_id': 'kb_local_idem', 'owner_id': 'kept'}, - ) - with pytest.raises(DocServiceError) as exc: - self.manager.run_idempotent( - '/v1/kbs/kb_local_idem:patch', - 'kb-update-idem', - clear_payload, - lambda: {'kb_id': 'kb_local_idem', 'owner_id': None}, - ) - assert exc.value.biz_code == 'E_IDEMPOTENCY_CONFLICT' - - def test_manager_callback_payload_fallback_and_delete_transition(self): - self.manager.create_kb('kb_local_callback', algo_id='__default__') - file_path = self._make_file('local_callback.txt', 'local callback content') - self.manager._upsert_doc( - doc_id='local-callback-doc', - filename='local_callback.txt', - path=file_path, - metadata={'case': 'callback'}, - source_type=SourceType.EXTERNAL, - ) - self.manager._ensure_kb_document('kb_local_callback', 'local-callback-doc') - queued_at = self.manager._upsert_parse_snapshot( - doc_id='local-callback-doc', - kb_id='kb_local_callback', - algo_id='__default__', - status=DocStatus.DELETING, - task_type=TaskType.DOC_DELETE, - current_task_id='local-delete-task', - queued_at=datetime.now(), - )['queued_at'] - - start_resp = self.manager.on_task_callback(TaskCallbackRequest( - callback_id='local-delete-start', - task_id='local-delete-task', - event_type=CallbackEventType.START, - status=DocStatus.WORKING, - payload={ - 'task_type': TaskType.DOC_DELETE.value, - 'doc_id': 'local-callback-doc', - 'kb_id': 'kb_local_callback', - 'algo_id': '__default__', - }, - )) - assert start_resp['ack'] is True - start_snapshot = self.manager._get_parse_snapshot('local-callback-doc', 'kb_local_callback', '__default__') - assert start_snapshot['status'] == DocStatus.DELETING.value - assert start_snapshot['queued_at'] == queued_at - - finish_resp = self.manager.on_task_callback(TaskCallbackRequest( - callback_id='local-delete-finish', - task_id='local-delete-task', - event_type=CallbackEventType.FINISH, - status=DocStatus.SUCCESS, - payload={ - 'task_type': TaskType.DOC_DELETE.value, - 'doc_id': 'local-callback-doc', - 'kb_id': 'kb_local_callback', - 'algo_id': '__default__', - }, - )) - assert finish_resp['ack'] is True - - finish_snapshot = self.manager._get_parse_snapshot('local-callback-doc', 'kb_local_callback', '__default__') - assert finish_snapshot['status'] == DocStatus.DELETED.value - assert self.manager._has_kb_document('kb_local_callback', 'local-callback-doc') is False - assert self.manager._get_doc('local-callback-doc')['upload_status'] == DocStatus.DELETED.value - - def test_parser_client_algo_endpoint_fallback(self): - client = _ParserClient(parser_url='http://parser.test') - calls = [] - - def fake_get(path, params=None): - del params - calls.append(path) - if path == '/v1/algo/list': - raise RuntimeError('parser http error: 404 missing route') - if path == '/algo/list': - return { - 'code': 200, - 'msg': 'success', - 'data': [{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}], - } - if path == 
'/v1/algo/__default__/groups': - raise RuntimeError('parser http error: 404 missing route') - if path == '/algo/__default__/group/info': - return { - 'code': 200, - 'msg': 'success', - 'data': [{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}], - } - raise AssertionError(path) - - client._get = fake_get - algo_resp = client.list_algorithms() - group_resp = client.get_algorithm_groups('__default__') - - assert algo_resp.code == 200 - assert group_resp.code == 200 - assert calls == [ - '/v1/algo/list', - '/algo/list', - '/v1/algo/__default__/groups', - '/algo/__default__/group/info', - ] diff --git a/tests/basic_tests/RAG/test_document.py b/tests/basic_tests/RAG/test_document.py index 5db8f2319..046805605 100644 --- a/tests/basic_tests/RAG/test_document.py +++ b/tests/basic_tests/RAG/test_document.py @@ -6,18 +6,11 @@ from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag.global_metadata import RAG_DOC_PATH, RAG_DOC_ID from lazyllm.tools.rag import Document, Retriever, TransformArgs, AdaptiveTransform -from lazyllm.tools.rag.doc_manager import DocManager -from lazyllm.tools.rag.utils import DocListManager, gen_docid +from lazyllm.tools.rag.utils import gen_docid from lazyllm.launcher import cleanup -from lazyllm import config from unittest.mock import MagicMock import unittest -import httpx import os -import shutil -import io -import re -import json import time import tempfile @@ -41,6 +34,14 @@ def tearDown(self): self.tmp_file_a.close() self.tmp_file_b.close() + def _build_root_nodes(self, input_files): + root_nodes = [] + for path in input_files: + node = DocNode(group=LAZY_ROOT_NAME, text=os.path.basename(path)) + node._global_metadata = {RAG_DOC_ID: gen_docid(path), RAG_DOC_PATH: path} + root_nodes.append(node) + return {LAZY_ROOT_NAME: root_nodes, LAZY_IMAGE_GROUP: []} + def test_create_node_group_default(self): self.doc_impl._create_builtin_node_group('MyChunk', transform=lambda x: ','.split(x)) self.doc_impl._lazy_init() @@ -85,9 +86,69 @@ def test_add_files(self): new_doc = DocNode(text='new dummy text', group=LAZY_ROOT_NAME) new_doc._global_metadata = {RAG_DOC_ID: gen_docid(self.tmp_file_b.name), RAG_DOC_PATH: self.tmp_file_b.name} self.mock_directory_reader.load_data.return_value = {LAZY_ROOT_NAME: [new_doc], LAZY_IMAGE_GROUP: []} - self.doc_impl._processor._add_doc([self.tmp_file_b.name]) + self.doc_impl._processor.add_doc([self.tmp_file_b.name]) assert len(self.doc_impl.store.get_nodes(group=LAZY_ROOT_NAME)) == 2 + def test_dataset_path_sync_without_doc_list_manager(self): + self.mock_embed.return_value = [0.1, 0.2, 0.3] + with tempfile.TemporaryDirectory() as temp_dir: + file_path = os.path.join(temp_dir, 'test.txt') + with open(file_path, 'w') as file: + file.write('local dataset path') + + doc_impl = DocImpl(embed=self.mock_embed, dataset_path=temp_dir, enable_path_monitoring=False) + doc_impl._reader = MagicMock() + doc_impl._reader.load_data.side_effect = ( + lambda input_files, metadatas, split_nodes_by_type=True: self._build_root_nodes(input_files) + ) + doc_impl._lazy_init() + + nodes = doc_impl.store.get_nodes(group=LAZY_ROOT_NAME) + assert len(nodes) == 1 + assert nodes[0].global_metadata[RAG_DOC_ID] == gen_docid(file_path) + assert nodes[0].global_metadata[RAG_DOC_PATH] == file_path + assert not hasattr(doc_impl, '_dlm') + + def test_dataset_path_monitor_adds_and_removes_files(self): + self.mock_embed.return_value = [0.1, 0.2, 0.3] + with tempfile.TemporaryDirectory() as temp_dir: + file_a = os.path.join(temp_dir, 'a.txt') + with open(file_a, 
'w') as file: + file.write('a') + + doc_impl = DocImpl(embed=self.mock_embed, dataset_path=temp_dir, enable_path_monitoring=True) + doc_impl._reader = MagicMock() + doc_impl._reader.load_data.side_effect = ( + lambda input_files, metadatas, split_nodes_by_type=True: self._build_root_nodes(input_files) + ) + doc_impl._local_monitor_interval = 0.1 + doc_impl._lazy_init() + + def wait_for_doc_ids(expected_ids): + deadline = time.time() + 3 + while time.time() < deadline: + nodes = doc_impl.store.get_nodes(group=LAZY_ROOT_NAME) + current_ids = {node.global_metadata[RAG_DOC_ID] for node in nodes} + if current_ids == expected_ids: + return + time.sleep(0.1) + raise AssertionError(f'Expected doc ids {expected_ids}, got {current_ids}') + + try: + wait_for_doc_ids({gen_docid(file_a)}) + + file_b = os.path.join(temp_dir, 'b.txt') + with open(file_b, 'w') as file: + file.write('b') + wait_for_doc_ids({gen_docid(file_a), gen_docid(file_b)}) + + os.remove(file_a) + wait_for_doc_ids({gen_docid(file_b)}) + finally: + doc_impl._local_monitor_continue = False + if doc_impl._local_monitor_thread: + doc_impl._local_monitor_thread.join(timeout=1) + class TestDocument(unittest.TestCase): @classmethod def tearDownClass(cls): @@ -267,114 +328,3 @@ def test_get_window_nodes(self): assert len(window) == 5 assert window == sorted(window, key=lambda n: n.number) assert all(n.number in [1, 2, 3, 4, 5] for n in window) - -class TmpDir: - def __init__(self): - self.root_dir = os.path.expanduser(os.path.join(config['home'], 'rag_for_document_ut')) - self.rag_dir = os.path.join(self.root_dir, 'rag_master') - os.makedirs(self.rag_dir, exist_ok=True) - - def __del__(self): - shutil.rmtree(self.root_dir) - -class TestDocumentServer(unittest.TestCase): - def setUp(self): - self.dir = TmpDir() - self.dlm = DocListManager(path=self.dir.rag_dir, name=None, enable_path_monitoring=False) - - self.doc_impl = DocImpl(embed=MagicMock(), dlm=self.dlm) - self.doc_impl._lazy_init() - - doc_manager = DocManager(self.dlm) - self.server = lazyllm.ServerModule(doc_manager) - - self.server.start() - - url_pattern = r'(http://\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d+)' - self.doc_server_addr = re.findall(url_pattern, self.server._url)[0] - self.time_sleep = 30 - - def test_delete_files_in_store(self): - files = [('files', ('test1.txt', io.BytesIO(b'John\'s house is in Beijing'), 'text/palin')), - ('files', ('test2.txt', io.BytesIO(b'John\'s house is in Shanghai'), 'text/plain'))] - metadatas = [{'comment': 'comment1'}, {'signature': 'signature2'}] - params = dict(override='true', metadatas=json.dumps(metadatas)) - - url = f'{self.doc_server_addr}/upload_files' - response = httpx.post(url, params=params, files=files, timeout=10) - assert response.status_code == 200 and response.json().get('code') == 200, response.json() - ids = response.json().get('data')[0] - lazyllm.LOG.info(f'debug!!! 
ids -> {ids}') - assert len(ids) == 2 - - time.sleep(self.time_sleep) # waiting for worker thread to update newly uploaded files - - # make sure that ids are written into the store - nodes = self.doc_impl.store.get_nodes(group=LAZY_ROOT_NAME) - doc_ids = [] - doc_file_paths = [] - doc_metadatas = [] - test1_docid = None - test2_docid = None - for node in nodes: - doc_ids.append(node.global_metadata[RAG_DOC_ID]) - doc_file_paths.append(node.global_metadata.get(RAG_DOC_PATH, '')) - doc_metadatas.append(node.global_metadata) - if 'test1' in node.global_metadata.get(RAG_DOC_PATH, ''): - test1_docid = node.global_metadata[RAG_DOC_ID] - elif 'test2' in node.global_metadata.get(RAG_DOC_PATH, ''): - test2_docid = node.global_metadata[RAG_DOC_ID] - lazyllm.LOG.info(f'debug!!! doc_ids -> {doc_ids}\n') - lazyllm.LOG.info(f'debug!!! doc_file_paths -> {doc_file_paths}\n') - lazyllm.LOG.info(f'debug!!! doc_metadatas -> {doc_metadatas}\n') - assert test1_docid and test2_docid - assert set(doc_ids) == set(ids) - - url = f'{self.doc_server_addr}/delete_files' - response = httpx.post(url, json=dict(file_ids=[test1_docid])) - assert response.status_code == 200 and response.json().get('code') == 200 - - time.sleep(self.time_sleep) # waiting for worker thread to delete files - - nodes = self.doc_impl.store.get_nodes(group=LAZY_ROOT_NAME) - assert len(nodes) == 1 - assert nodes[0].global_metadata[RAG_DOC_ID] == test2_docid - cur_meta_dict = nodes[0].global_metadata - - url = f'{self.doc_server_addr}/add_metadata' - response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={'title': 'title2'})) - assert response.status_code == 200 and response.json().get('code') == 200 - time.sleep(self.time_sleep) - lazyllm.LOG.info(f'debug!!! cur_meta_dict -> {cur_meta_dict}\n') - assert cur_meta_dict['title'] == 'title2' - - response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={'title': 'TITLE2'})) - assert response.status_code == 200 and response.json().get('code') == 200 - time.sleep(self.time_sleep) - lazyllm.LOG.info(f'debug!!! 
cur_meta_dict -> {cur_meta_dict}\n') - assert cur_meta_dict['title'] == ['title2', 'TITLE2'] - - url = f'{self.doc_server_addr}/delete_metadata_item' - - response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={'title': 'TITLE2'})) - assert response.status_code == 200 and response.json().get('code') == 200 - time.sleep(self.time_sleep) - assert cur_meta_dict['title'] == ['title2'] - - url = f'{self.doc_server_addr}/reset_metadata' - response = httpx.post(url, json=dict(doc_ids=[test2_docid], - new_meta={'author': 'author2', 'signature': 'signature_new'})) - assert response.status_code == 200 and response.json().get('code') == 200 - time.sleep(self.time_sleep) - assert cur_meta_dict['signature'] == 'signature_new' and cur_meta_dict['author'] == 'author2' - - url = f'{self.doc_server_addr}/query_metadata' - response = httpx.post(url, json=dict(doc_id=test2_docid)) - - # make sure that only one file is left - response = httpx.get(f'{self.doc_server_addr}/list_files') - assert response.status_code == 200 and len(response.json().get('data')) == 1 - - def tearDown(self): - # Must clean up the server as all uploaded files will be deleted as they are in tmp dir - self.dlm.release() From 230fdd03e137426ec0c884cb06358ad14ca8e942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 24 Mar 2026 20:02:50 +0800 Subject: [PATCH 31/46] fix --- docs/en/API Reference/tools.md | 53 +++--- docs/zh/API Reference/tools.md | 53 +++--- lazyllm/components/deploy/relay/server.py | 43 +++-- lazyllm/docs/tools/tool_rag.py | 170 ++++++++++++++++++-- lazyllm/tools/rag/doc_service/doc_server.py | 3 +- lazyllm/tools/rag/document.py | 12 +- lazyllm/tools/rag/parsing_service/impl.py | 2 - lazyllm/tools/rag/parsing_service/server.py | 4 +- lazyllm/tools/rag/web.py | 7 +- tests/basic_tests/RAG/test_document.py | 94 ++++++++--- 10 files changed, 327 insertions(+), 114 deletions(-) diff --git a/docs/en/API Reference/tools.md b/docs/en/API Reference/tools.md index 3be3ecb24..e31bdf5e1 100644 --- a/docs/en/API Reference/tools.md +++ b/docs/en/API Reference/tools.md @@ -200,35 +200,29 @@ members: [find] exclude-members: -::: lazyllm.tools.rag.DocManager - members: document, list_kb_groups, add_files, reparse_files - exclude-members: +::: lazyllm.tools.rag.doc_service.DocServer + members: [upload, add, reparse, delete, transfer, patch_metadata, list_docs, get_doc, list_tasks, get_task, cancel_task, list_kbs, get_kb, list_chunks, list_algorithms, get_algorithm_info, create_kb, update_kb, batch_get_kbs, delete_kb, delete_kbs] + exclude-members: -::: lazyllm.tools.rag.utils.SqliteDocListManager - members: - - table_inited - - get_status_cond_and_params - - validate_paths - - update_need_reparsing - - list_files - - get_docs - - set_docs_new_meta - - fetch_docs_changed_meta - - list_all_kb_group - - add_kb_group - - list_kb_group_files - - delete_unreferenced_doc - - get_docs_need_reparse - - get_existing_paths_by_pattern - - update_file_message - - update_file_status - - add_files_to_kb_group - - delete_files_from_kb_group - - get_file_status - - update_kb_group - - release - - get_status_cond_and_params - exclude-members: +::: lazyllm.tools.rag.doc_service.base.AddFileItem + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.UploadRequest + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.AddRequest + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.TransferItem + members: + exclude-members: + +::: 
lazyllm.tools.rag.doc_service.base.TransferRequest + members: + exclude-members: ::: lazyllm.tools.rag.data_loaders.DirectoryReader members: load_data @@ -425,9 +419,6 @@ ::: lazyllm.tools.rag.rerank.ModuleReranker members: forward exclude-members: -::: lazyllm.tools.rag.utils.DocListManager - members: - exclude-members: ::: lazyllm.tools.rag.global_metadata.GlobalMetadataDesc members: exclude-members: diff --git a/docs/zh/API Reference/tools.md b/docs/zh/API Reference/tools.md index 1036a2fe4..90c13a0cf 100644 --- a/docs/zh/API Reference/tools.md +++ b/docs/zh/API Reference/tools.md @@ -189,35 +189,29 @@ members: [find] exclude-members: -::: lazyllm.tools.rag.DocManager - members: document, list_kb_groups, add_files, reparse_files - exclude-members: +::: lazyllm.tools.rag.doc_service.DocServer + members: [upload, add, reparse, delete, transfer, patch_metadata, list_docs, get_doc, list_tasks, get_task, cancel_task, list_kbs, get_kb, list_chunks, list_algorithms, get_algorithm_info, create_kb, update_kb, batch_get_kbs, delete_kb, delete_kbs] + exclude-members: -::: lazyllm.tools.rag.utils.SqliteDocListManager - members: - - table_inited - - get_status_cond_and_params - - validate_paths - - update_need_reparsing - - list_files - - get_docs - - set_docs_new_meta - - fetch_docs_changed_meta - - list_all_kb_group - - add_kb_group - - list_kb_group_files - - delete_unreferenced_doc - - get_docs_need_reparse - - get_existing_paths_by_pattern - - update_file_message - - update_file_status - - add_files_to_kb_group - - delete_files_from_kb_group - - get_file_status - - update_kb_group - - release - - get_status_cond_and_params - exclude-members: +::: lazyllm.tools.rag.doc_service.base.AddFileItem + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.UploadRequest + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.AddRequest + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.TransferItem + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.TransferRequest + members: + exclude-members: ::: lazyllm.tools.rag.data_loaders.DirectoryReader members: load_data @@ -414,9 +408,6 @@ ::: lazyllm.tools.rag.rerank.ModuleReranker members: forward exclude-members: -::: lazyllm.tools.rag.utils.DocListManager - members: - exclude-members: ::: lazyllm.tools.rag.global_metadata.GlobalMetadataDesc members: exclude-members: diff --git a/lazyllm/components/deploy/relay/server.py b/lazyllm/components/deploy/relay/server.py index dad7e3021..8bfadc251 100644 --- a/lazyllm/components/deploy/relay/server.py +++ b/lazyllm/components/deploy/relay/server.py @@ -1,15 +1,9 @@ -from lazyllm.common.utils import str2obj -import uvicorn import argparse import os import sys import inspect import traceback from types import GeneratorType -import lazyllm -from lazyllm import kwargs, package, load_obj -from lazyllm import FastapiApp, globals -from lazyllm.common import _trim_traceback, _register_trim_module import time import pickle import codecs @@ -18,9 +12,35 @@ from functools import partial from typing import Callable -from fastapi import FastAPI, Request -from fastapi.responses import Response, StreamingResponse -import requests + +def _inject_pythonpath(argv): + pythonpath = None + for index, arg in enumerate(argv): + if arg == '--pythonpath' and index + 1 < len(argv): + pythonpath = argv[index + 1] + break + if arg.startswith('--pythonpath='): + pythonpath = arg.split('=', 1)[1] + break + if pythonpath: + pythonpath = os.path.abspath(pythonpath) + if 
pythonpath in sys.path: + sys.path.remove(pythonpath) + sys.path.insert(0, pythonpath) + + +_inject_pythonpath(sys.argv[1:]) + +from lazyllm.common.utils import str2obj # noqa: E402 +import uvicorn # noqa: E402 +import lazyllm # noqa: E402 +from lazyllm import kwargs, package, load_obj # noqa: E402 +from lazyllm import FastapiApp, globals # noqa: E402 +from lazyllm.common import _trim_traceback, _register_trim_module # noqa: E402 + +from fastapi import FastAPI, Request # noqa: E402 +from fastapi.responses import Response, StreamingResponse # noqa: E402 +import requests # noqa: E402 # TODO(sunxiaoye): delete in the future lazyllm_module_dir = os.path.abspath(__file__) @@ -43,7 +63,10 @@ args = parser.parse_args() if args.pythonpath: - sys.path.append(args.pythonpath) + pythonpath = os.path.abspath(args.pythonpath) + if pythonpath in sys.path: + sys.path.remove(pythonpath) + sys.path.insert(0, pythonpath) func = load_obj(args.function) if args.before_function: diff --git a/lazyllm/docs/tools/tool_rag.py b/lazyllm/docs/tools/tool_rag.py index 69b43bb5d..4e812f294 100644 --- a/lazyllm/docs/tools/tool_rag.py +++ b/lazyllm/docs/tools/tool_rag.py @@ -5,6 +5,18 @@ add_chinese_doc = functools.partial(utils.add_chinese_doc, module=importlib.import_module('lazyllm.tools')) add_english_doc = functools.partial(utils.add_english_doc, module=importlib.import_module('lazyllm.tools')) add_example = functools.partial(utils.add_example, module=importlib.import_module('lazyllm.tools')) +add_doc_service_chinese_doc = functools.partial( + utils.add_chinese_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service') +) +add_doc_service_english_doc = functools.partial( + utils.add_english_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service') +) +add_doc_service_base_chinese_doc = functools.partial( + utils.add_chinese_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service.base') +) +add_doc_service_base_english_doc = functools.partial( + utils.add_english_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service.base') +) add_english_doc('Document', '''\ Initialize a document management module with optional embedding, storage, and user interface. @@ -14,17 +26,19 @@ Args: dataset_path (Optional[str]): Path to the dataset directory. If not found, the system will attempt to locate it in ``lazyllm.config["data_path"]``. embed (Optional[Union[Callable, Dict[str, Callable]]]): Embedding function or mapping of embedding functions. When a dictionary is provided, keys are embedding names and values are embedding models. - manager (Union[bool, str], optional): Whether to enable the document manager. If ``True``, launches a manager service. If ``'ui'``, also enables the document management web UI. Defaults to ``False``. + create_ui (bool, optional): Deprecated alias of ``manager`` kept for compatibility. + manager (Union[bool, str], optional): Whether to enable the document manager. If ``True``, launches ``DocServer`` together with a local parsing service. If ``'ui'``, also enables the document management web UI. server (Union[bool, int], optional): Whether to run a server interface for knowledge bases. ``True`` enables a default server, an integer specifies a custom port, and ``False`` disables it. Defaults to ``False``. name (Optional[str]): Name identifier for this document collection. Defaults to the system default name. launcher (Optional[Launcher]): Launcher instance for managing server processes. Defaults to a remote asynchronous launcher. 
- store_conf (Optional[Dict]): Storage configuration. Defaults to in-memory MapStore. - doc_fields (Optional[Dict[str, DocField]]): Metadata field configuration for storing and retrieving document attributes. - cloud (bool): Whether the dataset is stored in the cloud. Defaults to ``False``. doc_files (Optional[List[str]]): Temporary document files. When used, ``dataset_path`` must be ``None``. Only MapStore is supported in this mode. - processor (Optional[DocumentProcessor]): Document processing service. + doc_fields (Optional[Dict[str, DocField]]): Metadata field configuration for storing and retrieving document attributes. + store_conf (Optional[Dict]): Storage configuration. Defaults to in-memory MapStore. display_name (Optional[str]): Human-readable display name for this document module. Defaults to the collection name. description (Optional[str]): Description of the document collection. Defaults to ``"algorithm description"``. + schema_extractor (Optional[Union[LLMBase, SchemaExtractor]]): Optional schema extractor used for metadata schema analysis and registration. + doc_server_port (Optional[int]): Explicit local port for ``DocServer`` when ``manager`` is enabled. + enable_path_monitoring (Optional[bool]): Whether to watch the local dataset path for file additions and removals. Defaults to enabled for local documents without manager mode. ''') add_chinese_doc('Document', '''\ @@ -35,17 +49,153 @@ Args: dataset_path (Optional[str]): 数据集目录路径。如果路径不存在,系统会尝试在 ``lazyllm.config["data_path"]`` 中查找。 embed (Optional[Union[Callable, Dict[str, Callable]]]): 文档向量化函数或函数字典。若为字典,键为 embedding 名称,值为对应的模型。 - manager (Union[bool, str], optional): 是否启用文档管理服务。``True`` 表示启动管理服务;``'ui'`` 表示同时启动 Web 管理界面;默认 ``False``。 + create_ui (bool, optional): ``manager`` 的兼容别名,已废弃。 + manager (Union[bool, str], optional): 是否启用文档管理服务。``True`` 表示启动 ``DocServer`` 及其本地 parsing service;``'ui'`` 表示同时启动 Web 管理界面。 server (Union[bool, int], optional): 是否为知识库运行服务接口。``True`` 表示启动默认服务;整型数值表示自定义端口;``False`` 表示关闭。默认为 ``False``。 name (Optional[str]): 文档集合的名称标识符。默认为系统默认名称。 launcher (Optional[Launcher]): 启动器实例,用于管理服务进程。默认使用远程异步启动器。 - store_conf (Optional[Dict]): 存储配置。默认使用内存中的 MapStore。 - doc_fields (Optional[Dict[str, DocField]]): 元数据字段配置,用于存储和检索文档属性。 - cloud (bool): 是否为云端数据集。默认为 ``False``。 doc_files (Optional[List[str]]): 临时文档文件列表。当使用此参数时,``dataset_path`` 必须为 ``None``,且仅支持 MapStore。 - processor (Optional[DocumentProcessor]): 文档处理服务。 + doc_fields (Optional[Dict[str, DocField]]): 元数据字段配置,用于存储和检索文档属性。 + store_conf (Optional[Dict]): 存储配置。默认使用内存中的 MapStore。 display_name (Optional[str]): 文档模块的可读显示名称。默认为集合名称。 description (Optional[str]): 文档集合的描述。默认为 ``"algorithm description"``。 + schema_extractor (Optional[Union[LLMBase, SchemaExtractor]]): 可选 schema extractor,用于元数据 schema 分析与注册。 + doc_server_port (Optional[int]): ``manager`` 启用时 ``DocServer`` 使用的本地端口。 + enable_path_monitoring (Optional[bool]): 是否监控本地数据目录的文件新增和删除。对非 manager 的本地文档默认开启。 +''') + +add_doc_service_english_doc('DocServer', '''\ +Primary entry point of the refactored document service. + +``DocServer`` manages document upload/add/reparse/delete flows, task tracking, knowledge-base management, +chunk inspection, and cross-kb transfer. It is the recommended replacement for the legacy ``DocManager`` / +``DocListManager`` APIs. + +Args: + port (Optional[int]): Local service port when starting an in-process server. + url (Optional[str]): Existing doc_service URL. When provided, the instance works as a remote client. 
+ parser_url (Optional[str]): Parsing service URL used by the local doc_service instance. + db_config (Optional[Dict[str, Any]]): Metadata database configuration for doc_service. + parser_db_config (Optional[Dict[str, Any]]): Parsing task database configuration for the parsing service. + parser_poll_interval (float): Poll interval used by local parser coordination. + storage_dir (Optional[str]): Local storage directory for uploaded files. + callback_url (Optional[str]): Callback URL used to receive parsing task updates. + launcher: Launcher used to start local services. +''') + +add_doc_service_chinese_doc('DocServer', '''\ +重构后文档服务的主入口。 + +``DocServer`` 负责文档上传/添加/重解析/删除、任务跟踪、知识库管理、chunk 查看,以及跨知识库文档转移。 +它是 legacy ``DocManager`` / ``DocListManager`` API 的推荐替代方案。 + +Args: + port (Optional[int]): 本地启动服务时使用的端口。 + url (Optional[str]): 已存在的 doc_service 地址;提供后当前实例作为远程客户端使用。 + parser_url (Optional[str]): 本地 doc_service 使用的 parsing service 地址。 + db_config (Optional[Dict[str, Any]]): doc_service 元数据数据库配置。 + parser_db_config (Optional[Dict[str, Any]]): parsing service 任务数据库配置。 + parser_poll_interval (float): 本地解析协调使用的轮询间隔。 + storage_dir (Optional[str]): 上传文件保存目录。 + callback_url (Optional[str]): 接收解析任务回调的地址。 + launcher: 本地服务启动器。 +''') + +add_doc_service_english_doc('DocServer.list_chunks', '''\ +List parsed chunks for a document through the ``/v1/chunks`` endpoint. + +Args: + kb_id (str): Knowledge-base ID. + doc_id (str): Source document ID. + group (str): Node group name to inspect. + algo_id (str): Algorithm ID. + page (int): 1-based page number. + page_size (int): Number of chunks per page. + offset (Optional[int]): Explicit offset. When omitted, the service derives it from ``page`` and ``page_size``. + +Returns: + Paginated chunk data including ``items`` and ``total``. +''') + +add_doc_service_chinese_doc('DocServer.list_chunks', '''\ +通过 ``/v1/chunks`` 接口分页查看文档的解析 chunk。 + +Args: + kb_id (str): 知识库 ID。 + doc_id (str): 文档 ID。 + group (str): 要查看的节点组名。 + algo_id (str): 算法 ID。 + page (int): 从 1 开始的页码。 + page_size (int): 每页 chunk 数量。 + offset (Optional[int]): 显式偏移量;未传时服务端会根据 ``page`` 和 ``page_size`` 推导。 + +Returns: + 包含 ``items`` 与 ``total`` 的分页结果。 +''') + +add_doc_service_english_doc('DocServer.transfer', '''\ +Transfer parsed documents between knowledge bases under the same algorithm. + +The request body is a ``TransferRequest``. Each transfer item must provide a unique ``target_doc_id`` in the target +knowledge base. Transfer across different algorithms is not supported. Optional ``target_filename`` and +``target_file_path`` can override the destination file name/path recorded for the transferred document. +''') + +add_doc_service_chinese_doc('DocServer.transfer', '''\ +在同一算法下的不同知识库之间转移已解析文档。 + +请求体为 ``TransferRequest``。每个转移项都必须在目标知识库中提供唯一的 ``target_doc_id``。 +当前不支持跨算法 transfer。可选字段 ``target_filename`` 与 ``target_file_path`` 用于覆盖目标文档记录的文件名或文件路径。 +''') + +add_doc_service_base_english_doc('TransferItem', '''\ +Single item in a document transfer request. + +Args: + doc_id (str): Source document ID. + target_doc_id (str): Required destination document ID. Must be unique in the target knowledge base. + source_kb_id (str): Source knowledge-base ID. + source_algo_id (str): Source algorithm ID. + target_kb_id (str): Destination knowledge-base ID. + target_algo_id (str): Destination algorithm ID. + target_metadata (Optional[Dict[str, Any]]): Metadata patch applied on top of the source document metadata for the transferred target document. + target_filename (Optional[str]): Target file name override. 
+ target_file_path (Optional[str]): Target file path override. If set together with ``target_filename``, both must + point to the same basename. + mode (str): Transfer mode. Supports ``copy`` and ``move``. +''') + +add_doc_service_base_chinese_doc('TransferItem', '''\ +文档转移请求中的单个条目。 + +Args: + doc_id (str): 源文档 ID。 + target_doc_id (str): 必填的目标文档 ID,在目标知识库中必须唯一。 + source_kb_id (str): 源知识库 ID。 + source_algo_id (str): 源算法 ID。 + target_kb_id (str): 目标知识库 ID。 + target_algo_id (str): 目标算法 ID。 + target_metadata (Optional[Dict[str, Any]]): 基于源文档 metadata 做继承后,再覆盖写入目标文档的 metadata patch。 + target_filename (Optional[str]): 目标文件名覆盖值。 + target_file_path (Optional[str]): 目标文件路径覆盖值;若与 ``target_filename`` 同时传入,二者 basename + 必须一致。 + mode (str): 转移模式,支持 ``copy`` 与 ``move``。 +''') + +add_doc_service_base_english_doc('TransferRequest', '''\ +Batch transfer request for ``DocServer.transfer``. + +Args: + items (List[TransferItem]): Transfer items to execute. + idempotency_key (Optional[str]): Optional idempotency key for safe retries. +''') + +add_doc_service_base_chinese_doc('TransferRequest', '''\ +``DocServer.transfer`` 使用的批量转移请求。 + +Args: + items (List[TransferItem]): 要执行的转移条目列表。 + idempotency_key (Optional[str]): 可选幂等键,用于安全重试。 ''') add_example('Document', '''\ diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py index ce928ebe1..e6fac297e 100644 --- a/lazyllm/tools/rag/doc_service/doc_server.py +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -599,6 +599,7 @@ def __init__( parser_poll_interval: float = 0.05, storage_dir: Optional[str] = None, callback_url: Optional[str] = None, + pythonpath: Optional[str] = None, launcher=None, ): super().__init__() @@ -619,7 +620,7 @@ def __init__( parser_url=parser_url, callback_url=callback_url, ) - self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) + self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher, pythonpath=pythonpath) @staticmethod def _register_openapi_routes(openapi_app: 'fastapi.FastAPI', impl: 'DocServer._Impl'): diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 8f292d1fc..411e646d2 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -23,6 +23,8 @@ import functools import weakref +_LOCAL_PYTHONPATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) + class CallableDict(dict): def __call__(self, cls, *args, **kw): @@ -81,7 +83,15 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, description=description, schema_extractor=schema_extractor)}) if manager: - self._manager = DocServer(launcher=self._launcher, storage_dir=dataset_path, port=doc_server_port) + self._doc_processor = DocumentProcessor(launcher=self._launcher, pythonpath=_LOCAL_PYTHONPATH) + self._doc_processor.start() + self._manager = DocServer( + launcher=self._launcher, + storage_dir=dataset_path, + port=doc_server_port, + parser_url=self._doc_processor.url, + pythonpath=_LOCAL_PYTHONPATH, + ) if manager == 'ui': self._docweb = DocWebModule(doc_server=self._manager) if server: self._kbs = ServerModule(self._kbs, port=(None if isinstance(server, bool) else int(server))) self._global_metadata_desc = doc_fields diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index a7e0a5929..6c52deae8 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -5,8 +5,6 @@ from collections import defaultdict, 
deque from concurrent.futures import ThreadPoolExecutor from functools import cached_property -from itertools import repeat - from lazyllm import LOG from ..data_loaders import DirectoryReader diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index d8ee1f307..43ecd853e 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -832,7 +832,7 @@ def __init__(self, port: int = None, url: str = None, num_workers: int = 1, launcher: Optional[Launcher] = None, post_func: Optional[Callable] = None, path_prefix: Optional[str] = None, callback_url: Optional[str] = None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: Optional[List[str]] = None, - high_priority_workers: int = 1): + high_priority_workers: int = 1, pythonpath: Optional[str] = None): super().__init__() self._raw_impl = None # save the reference of the original Impl object self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') @@ -849,7 +849,7 @@ def __init__(self, port: int = None, url: str = None, num_workers: int = 1, high_priority_workers=high_priority_workers, callback_url=callback_url ) - self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) + self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher, pythonpath=pythonpath) else: self._impl = UrlModule(url=ensure_call_endpoint(url)) diff --git a/lazyllm/tools/rag/web.py b/lazyllm/tools/rag/web.py index e0a7768fc..102828f03 100644 --- a/lazyllm/tools/rag/web.py +++ b/lazyllm/tools/rag/web.py @@ -240,10 +240,11 @@ def wait(self): self.demo.block_thread() def stop(self): - if self.demo: - self.demo.close() + demo = self.__dict__.get('demo') + if demo: + demo.close() del self.demo - self.demo, self.url = None, '' + self.url = '' def _find_can_use_network_port(self): for port in self.port: diff --git a/tests/basic_tests/RAG/test_document.py b/tests/basic_tests/RAG/test_document.py index 046805605..2f06df6ba 100644 --- a/tests/basic_tests/RAG/test_document.py +++ b/tests/basic_tests/RAG/test_document.py @@ -6,11 +6,13 @@ from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag.global_metadata import RAG_DOC_PATH, RAG_DOC_ID from lazyllm.tools.rag import Document, Retriever, TransformArgs, AdaptiveTransform +import lazyllm.tools.rag.document as document_module from lazyllm.tools.rag.utils import gen_docid from lazyllm.launcher import cleanup -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import unittest import os +import shutil import time import tempfile @@ -154,11 +156,30 @@ class TestDocument(unittest.TestCase): def tearDownClass(cls): cleanup() + def setUp(self): + self._temp_dirs = [] + + def tearDown(self): + for temp_dir in self._temp_dirs: + shutil.rmtree(temp_dir, ignore_errors=True) + + def _build_dataset(self, text: str = None) -> str: + temp_dir = tempfile.mkdtemp(prefix='lazyllm_document_') + self._temp_dirs.append(temp_dir) + file_path = os.path.join(temp_dir, 'rag.txt') + with open(file_path, 'w', encoding='utf-8') as file: + file.write(text or '\n'.join( + f'第{i}段:何为天道?人法地,地法天,天法道,道法自然。什么是道?道可道,非常道。' + for i in range(1, 241) + )) + return temp_dir + def test_register_global_and_local(self): Document.create_node_group('Chunk1', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50) Document.create_node_group('Chunk2', transform=TransformArgs( f=SentenceSplitter, kwargs=dict(chunk_size=256, 
chunk_overlap=25))) - doc1, doc2 = Document('rag_master'), Document('rag_master') + dataset_path = self._build_dataset() + doc1, doc2 = Document(dataset_path), Document(dataset_path) doc2.create_node_group('Chunk2', transform=dict( f=SentenceSplitter, kwargs=dict(chunk_size=128, chunk_overlap=10))) doc2.create_node_group('Chunk3', trans_node=True, @@ -186,8 +207,9 @@ def test_register_global_and_local(self): assert isinstance(r[0], DocNode) def test_create_document(self): - Document('rag_master') - Document('rag_master/') + dataset_path = self._build_dataset() + Document(dataset_path) + Document(dataset_path + os.sep) def test_register_with_pattern(self): Document.create_node_group('AdaptiveChunk1', transform=[ @@ -197,7 +219,7 @@ def test_register_with_pattern(self): dict(f=SentenceSplitter, pattern=(lambda x: x.endswith('.txt')), kwargs=dict(chunk_size=512, chunk_overlap=50)), TransformArgs(f=SentenceSplitter, pattern=None, kwargs=dict(chunk_size=256, chunk_overlap=25))])) - doc = Document('rag_master') + doc = Document(self._build_dataset()) doc._impl._lazy_init() retriever = Retriever(doc, 'AdaptiveChunk1', similarity='bm25', topk=2) retriever('什么是道') @@ -205,7 +227,7 @@ def test_register_with_pattern(self): retriever('什么是道') def test_create_node_group_with_ref(self): - doc = Document('rag_master') + doc = Document(self._build_dataset()) # Create parent node group doc.create_node_group('parent_chunk', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50) # Create ref node group under parent @@ -230,7 +252,7 @@ def transform_with_ref(text, ref): assert 'doc_test_ref' in ref_nodes[0].text def test_create_node_group_with_invalid_ref(self): - doc = Document('rag_master') + doc = Document(self._build_dataset()) # Create two independent node groups doc.create_node_group('group_a', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50) doc.create_node_group('group_b', transform=SentenceSplitter, chunk_size=256, chunk_overlap=25) @@ -251,7 +273,7 @@ def test_find(self): # root --- CoarseChunk < /- chunk21 # \ \- chunk2 < # \- FineChunk \- chunk22 - doc = Document('rag_master') + doc = Document(self._build_dataset()) doc.create_node_group('chunk1', parent=Document.CoarseChunk, transform=dict(f=SentenceSplitter, kwargs=dict(chunk_size=256, chunk_overlap=25))) doc.create_node_group('chunk11', parent='chunk1', @@ -276,22 +298,48 @@ def _test_impl(group, target): _test_impl(group, target) def test_doc_web_module(self): - import time - import requests - doc = Document('rag_master', manager='ui') - doc.create_kb_group(name='test_group') - doc2 = Document('rag_master', manager=doc.manager, name='test_group2') - doc.start() - time.sleep(4) - url = doc._manager._docweb.url - response = requests.get(url) - assert response.status_code == 200 - assert doc2._curr_group == 'test_group2' - assert doc2.manager == doc.manager - doc.stop() + dataset_path = self._build_dataset() + doc = Document(dataset_path, manager='ui') + try: + doc.create_kb_group(name='test_group') + doc2 = Document(dataset_path, manager=doc.manager, name='test_group2') + assert hasattr(doc._manager, '_docweb') + assert doc2._curr_group == 'test_group2' + assert doc2.manager == doc.manager + finally: + doc.stop() + + def test_doc_web_module_uses_workspace_pythonpath(self): + dataset_path = self._build_dataset() + calls = {} + + class FakeDocumentProcessor: + def __init__(self, *args, **kwargs): + calls['processor_pythonpath'] = kwargs.get('pythonpath') + self.url = 'http://127.0.0.1:19001/generate' + + def start(self): + 
calls['processor_started'] = True + + class FakeDocServer: + def __init__(self, *args, **kwargs): + calls['doc_server_pythonpath'] = kwargs.get('pythonpath') + calls['parser_url'] = kwargs.get('parser_url') + self._url = 'http://127.0.0.1:19002/generate' + + with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ + patch('lazyllm.tools.rag.document.DocServer', FakeDocServer): + doc = Document(dataset_path, manager='ui') + try: + assert calls['processor_started'] is True + assert calls['processor_pythonpath'] == document_module._LOCAL_PYTHONPATH + assert calls['doc_server_pythonpath'] == document_module._LOCAL_PYTHONPATH + assert calls['parser_url'] == 'http://127.0.0.1:19001/generate' + finally: + doc.stop() def test_get_nodes(self): - doc = Document('rag_master') + doc = Document(self._build_dataset()) doc.create_node_group('chunk1', parent=Document.CoarseChunk, transform=dict(f=SentenceSplitter, kwargs=dict(chunk_size=256, chunk_overlap=25))) doc.activate_groups(groups=['chunk1']) @@ -301,7 +349,7 @@ def test_get_nodes(self): assert n.number == 2 def test_get_window_nodes(self): - doc = Document('rag_master') + doc = Document(self._build_dataset()) doc.create_node_group('chunk1', parent=Document.CoarseChunk, transform=dict(f=SentenceSplitter, kwargs=dict(chunk_size=128, chunk_overlap=12))) doc.activate_groups(groups=['chunk1']) From 2ac3cf9303e9662ef301796d918b99f477c8ea0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 30 Mar 2026 18:45:47 +0800 Subject: [PATCH 32/46] fix: honor excluded metadata in text splitter --- lazyllm/tools/rag/transform/base.py | 9 +-------- tests/basic_tests/RAG/test_transform.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/lazyllm/tools/rag/transform/base.py b/lazyllm/tools/rag/transform/base.py index 67d64b8d9..52cac8313 100644 --- a/lazyllm/tools/rag/transform/base.py +++ b/lazyllm/tools/rag/transform/base.py @@ -1,12 +1,11 @@ from copy import copy as copy_obj -from enum import Enum from dataclasses import dataclass, field from typing import ( Any, List, Union, Optional, Tuple, AbstractSet, Collection, Literal, Callable, Dict, Iterator ) from lazyllm import LOG -from ..doc_node import DocNode, RichDocNode +from ..doc_node import DocNode, RichDocNode, MetadataMode from ....common.deprecated import deprecated from lazyllm import ThreadPoolExecutor from itertools import chain @@ -21,12 +20,6 @@ from lazyllm.thirdparty import nltk from lazyllm.thirdparty import transformers -class MetadataMode(str, Enum): - ALL = 'ALL' - EMBED = 'EMBED' - LLM = 'LLM' - NONE = 'NONE' - @dataclass class _Split: text: str diff --git a/tests/basic_tests/RAG/test_transform.py b/tests/basic_tests/RAG/test_transform.py index 79244aee2..a216d566b 100644 --- a/tests/basic_tests/RAG/test_transform.py +++ b/tests/basic_tests/RAG/test_transform.py @@ -1176,6 +1176,23 @@ def test_get_metadata_size(self): metadata_size = splitter._get_metadata_size(node) assert metadata_size == 0 + def test_get_metadata_size_respects_excluded_metadata_keys(self): + splitter = _TextSplitterBase(chunk_size=200, overlap=10) + node = DocNode( + text='Hello, world! 
This is a test.', + metadata={ + 'file_name': 'test.pdf', + 'title': 'Section 1', + 'lines': [{'content': 'x' * 2000, 'page': 1}], + }, + ) + node.excluded_embed_metadata_keys = ['lines'] + node.excluded_llm_metadata_keys = ['lines'] + + metadata_size = splitter._get_metadata_size(node) + + assert metadata_size < 200 + def test_transform_returns_chunks(self, doc_node): splitter = _TextSplitterBase(chunk_size=20, overlap=10) chunks = splitter([doc_node]) From 6cc9176621b16cdfaf76f5554d084e11622208c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 30 Mar 2026 18:59:26 +0800 Subject: [PATCH 33/46] fix: guard doc path conflicts in doc manager --- lazyllm/tools/rag/doc_service/doc_manager.py | 200 ++++------ .../basic_tests/RAG/test_doc_service_mock.py | 345 +++++++++++++++++- 2 files changed, 418 insertions(+), 127 deletions(-) diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index e522de9a1..913abea52 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -236,8 +236,7 @@ def set_callback_url(self, callback_url: str): def _ensure_indexes(self): stmts = [ - 'DROP INDEX IF EXISTS uq_docs_path', - 'CREATE INDEX IF NOT EXISTS idx_docs_path ON lazyllm_documents(path)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_docs_path ON lazyllm_documents(path)', 'CREATE INDEX IF NOT EXISTS idx_documents_upload_status ON lazyllm_documents(upload_status)', 'CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON lazyllm_documents(updated_at)', 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_display_name ' @@ -589,6 +588,12 @@ def _get_doc(self, doc_id: str): row = session.query(Doc).filter(Doc.doc_id == doc_id).first() return _orm_to_dict(row) if row else None + def _get_doc_by_path(self, path: str): + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.path == path).first() + return _orm_to_dict(row) if row else None + def _upsert_doc( self, doc_id: str, @@ -602,34 +607,51 @@ def _upsert_doc( file_type = os.path.splitext(path)[1].lstrip('.').lower() or None size_bytes = os.path.getsize(path) if os.path.exists(path) else None content_hash = _sha256_file(path) if os.path.exists(path) else None - with self._db_manager.get_session() as session: - Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) - row = session.query(Doc).filter(Doc.doc_id == doc_id).first() - if row is None: - row = Doc( - doc_id=doc_id, - filename=filename, - path=path, - meta=_to_json(metadata), - upload_status=upload_status.value, - source_type=source_type.value, - file_type=file_type, - content_hash=content_hash, - size_bytes=size_bytes, - created_at=now, - updated_at=now, - ) - else: - row.filename = filename - row.path = path - row.meta = _to_json(metadata) - row.upload_status = upload_status.value - row.source_type = source_type.value - row.file_type = file_type - row.content_hash = content_hash - row.size_bytes = size_bytes - row.updated_at = now - session.add(row) + try: + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + path_row = session.query(Doc).filter(Doc.path == path).first() + if path_row is not None and path_row.doc_id != doc_id: + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc path already exists: {path}', + {'doc_id': 
path_row.doc_id, 'path': path}, + ) + if row is None: + row = Doc( + doc_id=doc_id, + filename=filename, + path=path, + meta=_to_json(metadata), + upload_status=upload_status.value, + source_type=source_type.value, + file_type=file_type, + content_hash=content_hash, + size_bytes=size_bytes, + created_at=now, + updated_at=now, + ) + else: + row.filename = filename + row.path = path + row.meta = _to_json(metadata) + row.upload_status = upload_status.value + row.source_type = source_type.value + row.file_type = file_type + row.content_hash = content_hash + row.size_bytes = size_bytes + row.updated_at = now + session.add(row) + except IntegrityError as exc: + existing = self._get_doc_by_path(path) + if existing is not None and existing.get('doc_id') != doc_id: + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc path already exists: {path}', + {'doc_id': existing['doc_id'], 'path': path}, + ) from exc + raise return self._get_doc(doc_id) def _set_doc_upload_status(self, doc_id: str, status: DocStatus): @@ -877,14 +899,13 @@ def _call_parser_client(self, method, *args, **kwargs): def _create_parser_task(self, task_id: str, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, parser_kb_id: Optional[str] = None, - transfer_params: Optional[Dict[str, Any]] = None, - parser_doc_id: Optional[str] = None): + transfer_params: Optional[Dict[str, Any]] = None): if task_type in (TaskType.DOC_ADD, TaskType.DOC_TRANSFER): if not file_path: raise RuntimeError(f'file_path is required for task_type {task_type.value}') task_resp = self._call_parser_client( self._parser_client.add_doc, - task_id, algo_id, parser_kb_id or kb_id, parser_doc_id or doc_id, file_path, metadata, + task_id, algo_id, parser_kb_id or kb_id, doc_id, file_path, metadata, callback_url=self._callback_url, transfer_params=transfer_params, ) elif task_type == TaskType.DOC_REPARSE: @@ -917,7 +938,7 @@ def _enqueue_task( file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, cleanup_policy: Optional[str] = None, parser_kb_id: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None, - extra_message: Optional[Dict[str, Any]] = None, parser_doc_id: Optional[str] = None, + extra_message: Optional[Dict[str, Any]] = None, ): task_id = str(uuid4()) task_message = { @@ -957,7 +978,7 @@ def _enqueue_task( self._create_parser_task( task_id, doc_id, kb_id, algo_id, task_type, file_path=file_path, metadata=metadata, reparse_group=reparse_group, - parser_kb_id=parser_kb_id, transfer_params=transfer_params, parser_doc_id=parser_doc_id, + parser_kb_id=parser_kb_id, transfer_params=transfer_params, ) except Exception as exc: finished_at = now_ts() @@ -990,10 +1011,6 @@ def _enqueue_task( def _apply_doc_upload_status(self, doc_id: str, task_type: TaskType, status: DocStatus): if task_type == TaskType.DOC_ADD: - self._set_doc_upload_status(doc_id, status) - return - if task_type == TaskType.DOC_TRANSFER: - self._set_doc_upload_status(doc_id, status) return if task_type == TaskType.DOC_DELETE: if status == DocStatus.DELETING: @@ -1088,7 +1105,7 @@ def _prepare_metadata_patch_items(self, request: MetadataPatchRequest) -> List[D prepared_items.append({'doc_id': item.doc_id, 'metadata': merged, 'file_path': doc.get('path')}) return prepared_items - def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, Any]]: # noqa: C901 + def 
_prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, Any]]: prepared_items = [] seen_pairs = set() seen_targets = set() @@ -1097,33 +1114,21 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An raise DocServiceError( 'E_INVALID_PARAM', f'invalid transfer mode: {item.mode}', {'mode': item.mode} ) - if item.target_doc_id == item.doc_id: - raise DocServiceError( - 'E_INVALID_PARAM', - 'target_doc_id must be different from source doc_id', - {'doc_id': item.doc_id, 'target_doc_id': item.target_doc_id}, - ) - item_key = (item.doc_id, item.source_kb_id, item.target_kb_id, item.target_doc_id) + item_key = (item.doc_id, item.source_kb_id, item.target_kb_id) if item_key in seen_pairs: raise DocServiceError( 'E_INVALID_PARAM', 'duplicate transfer item detected', - { - 'doc_id': item.doc_id, - 'source_kb_id': item.source_kb_id, - 'target_kb_id': item.target_kb_id, - 'target_doc_id': item.target_doc_id, - }, + {'doc_id': item.doc_id, 'source_kb_id': item.source_kb_id, 'target_kb_id': item.target_kb_id}, ) seen_pairs.add(item_key) - target_key = (item.target_doc_id, item.target_kb_id, item.target_algo_id) + target_key = (item.doc_id, item.target_kb_id, item.target_algo_id) if target_key in seen_targets: raise DocServiceError( 'E_INVALID_PARAM', 'duplicate transfer target detected', { 'doc_id': item.doc_id, - 'target_doc_id': item.target_doc_id, 'target_kb_id': item.target_kb_id, 'target_algo_id': item.target_algo_id, }, @@ -1145,14 +1150,11 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An self._validate_kb_algorithm(item.source_kb_id, item.source_algo_id) self._validate_kb_algorithm(item.target_kb_id, item.target_algo_id) self._assert_action_allowed(item.doc_id, item.source_kb_id, item.source_algo_id, 'transfer') - if ( - self._get_doc(item.target_doc_id) is not None - or self._has_kb_document(item.target_kb_id, item.target_doc_id) - ): + if self._has_kb_document(item.target_kb_id, item.doc_id): raise DocServiceError( 'E_STATE_CONFLICT', - f'doc already exists in target kb: {item.target_doc_id}', - {'target_doc_id': item.target_doc_id, 'target_kb_id': item.target_kb_id}, + f'doc already exists in target kb: {item.doc_id}', + {'doc_id': item.doc_id, 'target_kb_id': item.target_kb_id}, ) source_snapshot = self._get_parse_snapshot(item.doc_id, item.source_kb_id, item.source_algo_id) if source_snapshot is None or source_snapshot.get('status') != DocStatus.SUCCESS.value: @@ -1166,53 +1168,19 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An 'status': source_snapshot.get('status') if source_snapshot else None, }, ) - source_metadata = _from_json(doc.get('meta')) - target_metadata = dict(source_metadata) - if item.target_metadata: - target_metadata.update(item.target_metadata) - source_path = doc.get('path') - target_path = item.target_file_path - target_filename = item.target_filename - if target_path and target_filename: - resolved_name = os.path.basename(target_path) - if resolved_name and resolved_name != target_filename: - raise DocServiceError( - 'E_INVALID_PARAM', - 'target_filename must match the basename of target_file_path', - { - 'target_doc_id': item.target_doc_id, - 'target_filename': target_filename, - 'target_file_path': target_path, - }, - ) - if target_path and not target_filename: - target_filename = os.path.basename(target_path) or doc.get('filename') - if target_filename and not target_path: - if source_path: - target_path = os.path.join(os.path.dirname(source_path), 
target_filename) - else: - target_path = target_filename - if not target_filename: - target_filename = doc.get('filename') - if not target_path: - target_path = source_path prepared_items.append({ 'doc_id': item.doc_id, - 'target_doc_id': item.target_doc_id, 'source_kb_id': item.source_kb_id, 'source_algo_id': item.source_algo_id, 'target_kb_id': item.target_kb_id, 'target_algo_id': item.target_algo_id, 'mode': item.mode, - 'filename': target_filename, - 'source_type': SourceType(doc.get('source_type')), - 'source_file_path': source_path, - 'target_file_path': target_path, - 'metadata': target_metadata, + 'file_path': doc.get('path'), + 'metadata': _from_json(doc.get('meta')), 'transfer_params': { 'mode': 'mv' if item.mode == 'move' else 'cp', 'target_algo_id': item.target_algo_id, - 'target_doc_id': item.target_doc_id, + 'target_doc_id': item.doc_id, 'target_kb_id': item.target_kb_id, }, }) @@ -1232,7 +1200,7 @@ def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: path=file_path, metadata=metadata, source_type=request.source_type, - upload_status=DocStatus.WAITING, + upload_status=DocStatus.SUCCESS, ) self._ensure_kb_document(request.kb_id, doc_id) try: @@ -1359,26 +1327,15 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: for item in prepared_items: task_id = None try: - self._upsert_doc( - doc_id=item['target_doc_id'], - filename=item['filename'], - path=item['target_file_path'], - metadata=item['metadata'], - source_type=item['source_type'], - upload_status=DocStatus.WAITING, - ) - self._ensure_kb_document(item['target_kb_id'], item['target_doc_id']) + self._ensure_kb_document(item['target_kb_id'], item['doc_id']) task_id, snapshot = self._enqueue_task( - item['target_doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, + item['doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, idempotency_key=request.idempotency_key, - file_path=item['source_file_path'], + file_path=item['file_path'], metadata=item['metadata'], parser_kb_id=item['source_kb_id'], transfer_params=item['transfer_params'], - parser_doc_id=item['doc_id'], extra_message={ - 'source_doc_id': item['doc_id'], - 'target_doc_id': item['target_doc_id'], 'source_kb_id': item['source_kb_id'], 'source_algo_id': item['source_algo_id'], 'target_kb_id': item['target_kb_id'], @@ -1390,9 +1347,7 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: error_msg = None accepted = True except Exception as exc: - snapshot = self._get_parse_snapshot( - item['target_doc_id'], item['target_kb_id'], item['target_algo_id'] - ) or {} + snapshot = self._get_parse_snapshot(item['doc_id'], item['target_kb_id'], item['target_algo_id']) or {} task_id = task_id or snapshot.get('current_task_id') error_code = snapshot.get('last_error_code') if not error_code: @@ -1401,14 +1356,12 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: accepted = False items.append({ 'doc_id': item['doc_id'], - 'target_doc_id': item['target_doc_id'], 'task_id': task_id, 'source_kb_id': item['source_kb_id'], 'target_kb_id': item['target_kb_id'], 'source_algo_id': item['source_algo_id'], 'target_algo_id': item['target_algo_id'], 'mode': item['mode'], - 'target_file_path': item['target_file_path'], 'status': snapshot.get('status', DocStatus.FAILED.value), 'accepted': accepted, 'error_code': error_code, @@ -1532,8 +1485,6 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 ) if task_type == TaskType.DOC_ADD: 
self._apply_doc_upload_status(doc_id, task_type, DocStatus.WORKING) - elif task_type == TaskType.DOC_TRANSFER: - self._apply_doc_upload_status(doc_id, task_type, DocStatus.WORKING) elif task_type == TaskType.DOC_DELETE: self._apply_doc_upload_status(doc_id, task_type, DocStatus.DELETING) return {'ack': True, 'deduped': False, 'ignored_reason': None} @@ -1553,11 +1504,10 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 and task_message.get('mode') == 'move' ): source_kb_id = task_message.get('source_kb_id') - source_doc_id = task_message.get('source_doc_id') - if source_kb_id and source_doc_id and source_kb_id != kb_id: - self._remove_kb_document(source_kb_id, source_doc_id) - self._delete_parse_snapshots(source_doc_id, source_kb_id) - self._sync_doc_upload_status(source_doc_id) + if source_kb_id and source_kb_id != kb_id: + self._remove_kb_document(source_kb_id, doc_id) + self._delete_parse_snapshots(doc_id, source_kb_id) + self._sync_doc_upload_status(doc_id) self._update_task_record( callback.task_id, diff --git a/tests/basic_tests/RAG/test_doc_service_mock.py b/tests/basic_tests/RAG/test_doc_service_mock.py index 882acbebf..d585e3d75 100644 --- a/tests/basic_tests/RAG/test_doc_service_mock.py +++ b/tests/basic_tests/RAG/test_doc_service_mock.py @@ -5,11 +5,20 @@ import tempfile import time from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from uuid import uuid4 import pytest import requests from lazyllm.tools.rag.doc_service import DocServer +from lazyllm.tools.rag.doc_service.base import ( + AddFileItem, CallbackEventType, DeleteRequest, DocServiceError, DocStatus, KbUpdateRequest, ReparseRequest, + SourceType, TaskCallbackRequest, UploadRequest, +) +from lazyllm.tools.rag.doc_service.doc_manager import DocManager, _ParserClient +from lazyllm.tools.rag.parsing_service.base import TaskType +from lazyllm.tools.rag.utils import BaseResponse @pytest.mark.skip_on_win @@ -160,7 +169,6 @@ def test_p0_endpoints_and_core_flows(self): 'items': [ { 'doc_id': doc_add, - 'target_doc_id': 'copied-seed-doc-1', 'source_kb_id': 'kb_a', 'source_algo_id': '__default__', 'target_kb_id': 'kb_b', @@ -350,6 +358,36 @@ def test_upload_idempotency_replay_and_conflict(self): assert conflict.status_code == 409 assert conflict.json()['data']['biz_code'] == 'E_IDEMPOTENCY_CONFLICT' + def test_add_same_path_with_different_doc_id_returns_conflict(self): + create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_path_conflict'}, timeout=5) + assert create_kb.status_code == 200 + + first = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_path_conflict', + 'algo_id': '__default__', + 'items': [{'file_path': self._seed_path, 'doc_id': 'path-doc-1'}], + }, + timeout=8, + ) + assert first.status_code == 200 + + conflict = requests.post( + f'{self.base_url}/v1/docs/add', + json={ + 'kb_id': 'kb_path_conflict', + 'algo_id': '__default__', + 'items': [{'file_path': self._seed_path, 'doc_id': 'path-doc-2'}], + }, + timeout=8, + ) + assert conflict.status_code == 409 + body = conflict.json() + assert body['data']['biz_code'] == 'E_STATE_CONFLICT' + assert body['data']['path'] == self._seed_path + assert body['data']['doc_id'] == 'path-doc-1' + def test_upload_same_filename_does_not_override_existing_file(self): create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_same_name'}, timeout=5) assert create_kb.status_code == 200 @@ -514,7 +552,6 @@ def test_kb_algo_binding_and_transfer_validation(self): json={ 'items': 
[{ 'doc_id': doc_id, - 'target_doc_id': 'invalid-transfer-doc', 'source_kb_id': 'kb_bind', 'source_algo_id': '__default__', 'target_kb_id': 'kb_bind', @@ -653,3 +690,307 @@ def test_kb_update_pagination_and_batch_query(self): assert len(batch_data['items']) == 1 assert batch_data['items'][0]['kb_id'] == 'kb_page_1' assert batch_data['missing_kb_ids'] == ['kb_missing'] + + +class TestDocServiceMockLocal: + @classmethod + def setup_class(cls): + cls._tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_local_') + cls._seed_path = os.path.join(cls._tmp_dir, 'seed.txt') + with open(cls._seed_path, 'w', encoding='utf-8') as f: + f.write('local seed content') + cls._db_config = { + 'db_type': 'sqlite', + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'db_name': os.path.join(cls._tmp_dir, 'doc_service_local.db'), + } + cls.manager = DocManager(db_config=cls._db_config, parser_url='http://parser.test') + cls._pending_task_status = {} + + def _queue_task(task_id: str, final_status: DocStatus): + cls._pending_task_status[task_id] = final_status + + cls.manager._parser_client.add_doc = lambda task_id, algo_id, kb_id, doc_id, file_path, metadata=None, reparse_group=None: ( + _queue_task(task_id, DocStatus.SUCCESS) or + BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + ) + cls.manager._parser_client.update_meta = lambda task_id, algo_id, kb_id, doc_id, metadata=None, file_path=None: ( + _queue_task(task_id, DocStatus.SUCCESS) or + BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + ) + cls.manager._parser_client.delete_doc = lambda task_id, algo_id, kb_id, doc_id: ( + _queue_task(task_id, DocStatus.SUCCESS) or + BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + ) + cls.manager._parser_client.cancel_task = lambda task_id: BaseResponse( + code=200, msg='success', data={'task_id': task_id, 'cancel_status': True} + ) + cls.manager._parser_client.list_algorithms = lambda: BaseResponse( + code=200, msg='success', data=[{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}] + ) + cls.manager._parser_client.get_algorithm_groups = lambda algo_id: BaseResponse( + code=200, + msg='success', + data=[{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}] if algo_id == '__default__' else None, + ) + + @classmethod + def teardown_class(cls): + shutil.rmtree(cls._tmp_dir, ignore_errors=True) + + def _wait_task(self, task_id, target_statuses, timeout=8): + deadline = time.time() + timeout + last = None + while time.time() < deadline: + resp = self.manager.get_task(task_id) + assert resp.code == 200 + last = resp.data + if last['status'] in target_statuses: + return last + pending_status = self._pending_task_status.pop(task_id, None) + if pending_status is not None: + self.manager.on_task_callback(TaskCallbackRequest( + task_id=task_id, + event_type=CallbackEventType.FINISH, + status=pending_status, + )) + time.sleep(0.05) + raise AssertionError(f'task {task_id} not finished in time, last={last}') + + def _make_file(self, name: str, content: str): + file_path = os.path.join(self._tmp_dir, name) + with open(file_path, 'w', encoding='utf-8') as f: + f.write(content) + return file_path + + def test_manager_atomic_idempotency(self): + started = [] + + def handler(): + started.append(time.time()) + time.sleep(0.2) + return {'task_id': str(uuid4())} + + with ThreadPoolExecutor(max_workers=2) as pool: + future = 
pool.submit(self.manager.run_idempotent, '/local/atomic', 'same-key', {'k': 1}, handler) + time.sleep(0.05) + with pytest.raises(DocServiceError) as exc: + self.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) + result = future.result(timeout=2) + + assert exc.value.biz_code == 'E_IDEMPOTENCY_IN_PROGRESS' + replay = self.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) + assert len(started) == 1 + assert replay == result + + def test_manager_kb_algo_binding(self): + self.manager.create_kb('kb_local_bind', algo_id='__default__') + file_path = self._make_file('local_bind.txt', 'local bind content') + with pytest.raises(DocServiceError) as exc: + self.manager.upload(UploadRequest( + kb_id='kb_local_bind', + algo_id='wrong_algo', + items=[AddFileItem(file_path=file_path, doc_id='local-bind-doc')], + )) + assert exc.value.biz_code == 'E_INVALID_PARAM' + + def test_manager_stale_callback_and_state_conflict(self): + self.manager.create_kb('kb_local_stale', algo_id='__default__') + file_path = self._make_file('local_stale.txt', 'local stale content') + uploaded = self.manager.upload(UploadRequest( + kb_id='kb_local_stale', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='local-stale-doc')], + )) + self._wait_task(uploaded[0]['task_id'], {'SUCCESS'}) + first_task_id = self.manager.reparse(ReparseRequest( + kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], + ))[0] + second_task_id = self.manager.reparse(ReparseRequest( + kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], + ))[0] + stale_resp = self.manager.on_task_callback(TaskCallbackRequest( + callback_id='local-stale-callback', + task_id=first_task_id, + event_type=CallbackEventType.FINISH, + status=DocStatus.SUCCESS, + )) + assert stale_resp['ignored_reason'] == 'stale_task_callback' + self.manager.delete(DeleteRequest(kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'])) + with pytest.raises(DocServiceError) as exc: + self.manager.reparse(ReparseRequest( + kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], + )) + assert exc.value.biz_code == 'E_STATE_CONFLICT' + assert second_task_id != first_task_id + + def test_manager_missing_endpoint_surrogates(self): + self.manager.create_kb('kb_local_info', algo_id='__default__') + file_path = self._make_file('local_info.txt', 'local info content') + uploaded = self.manager.upload(UploadRequest( + kb_id='kb_local_info', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='local-info-doc')], + )) + algorithms = self.manager.list_algorithms_compat() + assert len(algorithms['items']) >= 1 + algo_info = self.manager.get_algorithm_info('__default__') + assert algo_info['algo_id'] == '__default__' + chunks = self.manager.list_chunks() + assert chunks['items'] == [] + tasks_batch = self.manager.get_tasks_batch([uploaded[0]['task_id']]) + assert len(tasks_batch['items']) == 1 + + def test_delete_kbs_empty_list_rejected(self): + with pytest.raises(DocServiceError) as exc: + self.manager.delete_kbs([]) + assert exc.value.biz_code == 'E_INVALID_PARAM' + + def test_manager_rejects_unknown_kb_algorithm(self): + with pytest.raises(DocServiceError) as exc: + self.manager.create_kb('kb_local_unknown_algo', algo_id='missing_algo') + assert exc.value.biz_code == 'E_INVALID_PARAM' + + def test_manager_update_kb_can_clear_nullable_fields(self): + self.manager.create_kb( + 'kb_local_clearable', + display_name='Clearable', + 
description='to be cleared', + owner_id='owner-x', + meta={'tag': 'x'}, + algo_id='__default__', + ) + updated = self.manager.update_kb( + 'kb_local_clearable', + display_name=None, + description=None, + owner_id=None, + meta=None, + explicit_fields={'display_name', 'description', 'owner_id', 'meta'}, + ) + assert updated['display_name'] is None + assert updated['description'] is None + assert updated['owner_id'] is None + assert updated['meta'] == {} + + def test_kb_update_idempotency_payload_distinguishes_omitted_and_null(self): + keep_req = KbUpdateRequest(display_name='Renamed', idempotency_key='kb-update-idem') + clear_req = KbUpdateRequest(display_name='Renamed', owner_id=None, idempotency_key='kb-update-idem') + + keep_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', keep_req) + clear_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', clear_req) + + assert keep_payload != clear_payload + + self.manager.run_idempotent( + '/v1/kbs/kb_local_idem:patch', + 'kb-update-idem', + keep_payload, + lambda: {'kb_id': 'kb_local_idem', 'owner_id': 'kept'}, + ) + with pytest.raises(DocServiceError) as exc: + self.manager.run_idempotent( + '/v1/kbs/kb_local_idem:patch', + 'kb-update-idem', + clear_payload, + lambda: {'kb_id': 'kb_local_idem', 'owner_id': None}, + ) + assert exc.value.biz_code == 'E_IDEMPOTENCY_CONFLICT' + + def test_manager_callback_payload_fallback_and_delete_transition(self): + self.manager.create_kb('kb_local_callback', algo_id='__default__') + file_path = self._make_file('local_callback.txt', 'local callback content') + self.manager._upsert_doc( + doc_id='local-callback-doc', + filename='local_callback.txt', + path=file_path, + metadata={'case': 'callback'}, + source_type=SourceType.EXTERNAL, + ) + self.manager._ensure_kb_document('kb_local_callback', 'local-callback-doc') + queued_at = self.manager._upsert_parse_snapshot( + doc_id='local-callback-doc', + kb_id='kb_local_callback', + algo_id='__default__', + status=DocStatus.DELETING, + task_type=TaskType.DOC_DELETE, + current_task_id='local-delete-task', + queued_at=datetime.now(), + )['queued_at'] + + start_resp = self.manager.on_task_callback(TaskCallbackRequest( + callback_id='local-delete-start', + task_id='local-delete-task', + event_type=CallbackEventType.START, + status=DocStatus.WORKING, + payload={ + 'task_type': TaskType.DOC_DELETE.value, + 'doc_id': 'local-callback-doc', + 'kb_id': 'kb_local_callback', + 'algo_id': '__default__', + }, + )) + assert start_resp['ack'] is True + start_snapshot = self.manager._get_parse_snapshot('local-callback-doc', 'kb_local_callback', '__default__') + assert start_snapshot['status'] == DocStatus.DELETING.value + assert start_snapshot['queued_at'] == queued_at + + finish_resp = self.manager.on_task_callback(TaskCallbackRequest( + callback_id='local-delete-finish', + task_id='local-delete-task', + event_type=CallbackEventType.FINISH, + status=DocStatus.SUCCESS, + payload={ + 'task_type': TaskType.DOC_DELETE.value, + 'doc_id': 'local-callback-doc', + 'kb_id': 'kb_local_callback', + 'algo_id': '__default__', + }, + )) + assert finish_resp['ack'] is True + + finish_snapshot = self.manager._get_parse_snapshot('local-callback-doc', 'kb_local_callback', '__default__') + assert finish_snapshot['status'] == DocStatus.DELETED.value + assert self.manager._has_kb_document('kb_local_callback', 'local-callback-doc') is False + assert self.manager._get_doc('local-callback-doc')['upload_status'] == DocStatus.DELETED.value + + def 
test_parser_client_algo_endpoint_fallback(self): + client = _ParserClient(parser_url='http://parser.test') + calls = [] + + def fake_get(path, params=None): + del params + calls.append(path) + if path == '/v1/algo/list': + raise RuntimeError('parser http error: 404 missing route') + if path == '/algo/list': + return { + 'code': 200, + 'msg': 'success', + 'data': [{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}], + } + if path == '/v1/algo/__default__/groups': + raise RuntimeError('parser http error: 404 missing route') + if path == '/algo/__default__/group/info': + return { + 'code': 200, + 'msg': 'success', + 'data': [{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}], + } + raise AssertionError(path) + + client._get = fake_get + algo_resp = client.list_algorithms() + group_resp = client.get_algorithm_groups('__default__') + + assert algo_resp.code == 200 + assert group_resp.code == 200 + assert calls == [ + '/v1/algo/list', + '/algo/list', + '/v1/algo/__default__/groups', + '/algo/__default__/group/info', + ] From 75415059eee1f2d81059b013babc98d0916e5761 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 30 Mar 2026 19:04:07 +0800 Subject: [PATCH 34/46] fix: make mineru pipeline backend more robust --- .../tools/servers/mineru/mineru_patches.py | 22 +++++++++++++------ .../servers/mineru/mineru_server_module.py | 7 +++--- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/lazyllm/tools/servers/mineru/mineru_patches.py b/lazyllm/tools/servers/mineru/mineru_patches.py index 4d061dd9e..1b4ec05b4 100644 --- a/lazyllm/tools/servers/mineru/mineru_patches.py +++ b/lazyllm/tools/servers/mineru/mineru_patches.py @@ -6,17 +6,24 @@ merge_para_with_text as pipeline_merge_para_with_text, get_title_level, ) -from mineru.backend.vlm import ( # noqa: NID002 - vlm_middle_json_mkcontent, -) -from mineru.backend.vlm.vlm_middle_json_mkcontent import ( # noqa: NID002 - merge_para_with_text as vlm_merge_para_with_text, -) from mineru.utils.enum_class import ( # noqa: NID002 BlockType, ContentType, ) +try: + from mineru.backend.vlm import ( # noqa: NID002 + vlm_middle_json_mkcontent, + ) + from mineru.backend.vlm.vlm_middle_json_mkcontent import ( # noqa: NID002 + merge_para_with_text as vlm_merge_para_with_text, + ) + HAS_VLM_BACKEND = True +except Exception: # pragma: no cover - optional dependency path + vlm_middle_json_mkcontent = None + vlm_merge_para_with_text = None + HAS_VLM_BACKEND = False + # patches to mineru (to output bbox) def _parse_line_spans(para_block, page_idx): @@ -243,4 +250,5 @@ def vlm_make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_s para_content['page_height'] = page_height return para_content -vlm_middle_json_mkcontent.make_blocks_to_content_list = vlm_make_blocks_to_content_list +if HAS_VLM_BACKEND: + vlm_middle_json_mkcontent.make_blocks_to_content_list = vlm_make_blocks_to_content_list diff --git a/lazyllm/tools/servers/mineru/mineru_server_module.py b/lazyllm/tools/servers/mineru/mineru_server_module.py index de5419f7a..99916321d 100644 --- a/lazyllm/tools/servers/mineru/mineru_server_module.py +++ b/lazyllm/tools/servers/mineru/mineru_server_module.py @@ -853,10 +853,11 @@ async def parse_pdf( # noqa: C901 try: results = {f: {} for f in files} + cache_enabled = bool(self.cache_manager and self.cache_manager.is_enabled) # honor use_cache for parsed outputs files_to_process = files - if use_cache and self.cache_manager.is_enabled: + if use_cache and cache_enabled: results, 
files_to_process = await self.cache_manager.check_parse_cache( files, results, effective_backend, return_md, return_content_list, table_enable, formula_enable, lang=lang, parse_method=parse_method @@ -872,7 +873,7 @@ async def parse_pdf( # noqa: C901 status_code=200, content={'result': results_list, 'unique_id': unique_id} ) - elif use_cache and not self.cache_manager.is_enabled: + elif use_cache and not cache_enabled: LOG.warning(f'[{req_id}] CACHE_DIR not set; use_cache ignored.') # 1) Convert to PDFs with caching support @@ -965,7 +966,7 @@ async def parse_pdf( # noqa: C901 results[src_path]['md_content'] = md_content if return_content_list and content_list: results[src_path]['content_list'] = content_list - if self.cache_manager.is_enabled: + if cache_enabled: await self.cache_manager.save_parse_result( hash_id, results[src_path], mode=effective_backend, table_enable=table_enable, formula_enable=formula_enable, From 084060b6e98ee428b54547cbf14e2a1e9665634e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Mon, 30 Mar 2026 19:55:21 +0800 Subject: [PATCH 35/46] fix: reset doc impl monitor lock on pickle --- lazyllm/tools/rag/doc_impl.py | 3 ++- tests/basic_tests/RAG/test_document.py | 9 +++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 8cf651832..9b3b05ee7 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -4,7 +4,7 @@ from enum import Enum from pydantic import BaseModel from typing import Callable, Dict, List, Optional, Set, Union, Tuple, Any, Type -from lazyllm import LOG, once_wrapper +from lazyllm import LOG, once_wrapper, reset_on_pickle from lazyllm.module import LLMBase from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, TransformArgs, TransformArgs as TArgs) @@ -63,6 +63,7 @@ def __hash__(self): return hash(self.name) lambda x: x._content, LAZY_IMAGE_GROUP, True) +@reset_on_pickle(('_local_monitor_lock', threading.Lock)) class DocImpl: _builtin_node_groups: Dict[str, Dict] = {} _global_node_groups: Dict[str, Dict] = {} diff --git a/tests/basic_tests/RAG/test_document.py b/tests/basic_tests/RAG/test_document.py index 2f06df6ba..be598f006 100644 --- a/tests/basic_tests/RAG/test_document.py +++ b/tests/basic_tests/RAG/test_document.py @@ -1,4 +1,5 @@ import lazyllm +import cloudpickle from lazyllm.tools.rag.doc_impl import DocImpl from lazyllm.tools.rag.store.store_base import LAZY_IMAGE_GROUP from lazyllm.tools.rag.transform import SentenceSplitter @@ -151,6 +152,14 @@ def wait_for_doc_ids(expected_ids): if doc_impl._local_monitor_thread: doc_impl._local_monitor_thread.join(timeout=1) + def test_doc_impl_can_be_pickled_before_lazy_init(self): + doc_impl = DocImpl(embed=self.mock_embed, doc_files=[self.tmp_file_a.name]) + serialized = cloudpickle.dumps(doc_impl) + restored = cloudpickle.loads(serialized) + + assert restored._local_monitor_lock is not None + assert restored._local_monitor_thread is None + class TestDocument(unittest.TestCase): @classmethod def tearDownClass(cls): From 251259d7f20b7d8b7e4ec3dd43d1e91d1b12e483 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 1 Apr 2026 21:14:19 +0800 Subject: [PATCH 36/46] fix opensearch segment deserialization --- lazyllm/tools/rag/store/segment/opensearch_store.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lazyllm/tools/rag/store/segment/opensearch_store.py 
b/lazyllm/tools/rag/store/segment/opensearch_store.py index a210e4612..0368bdfcd 100644 --- a/lazyllm/tools/rag/store/segment/opensearch_store.py +++ b/lazyllm/tools/rag/store/segment/opensearch_store.py @@ -274,9 +274,15 @@ def _serialize_node(self, segment: dict): def _deserialize_node(self, segment: dict) -> dict: seg = dict(segment) if self._global_metadata_desc and self._global_metadata_desc == BUILDIN_GLOBAL_META_DESC: - seg['meta'] = json.loads(seg.get('meta', '{}')) - seg['global_meta'] = json.loads(seg.get('global_meta', '{}')) - seg['image_keys'] = json.loads(seg.get('image_keys', '[]')) + meta = seg.get('meta', '{}') + global_meta = seg.get('global_meta', '{}') + image_keys = seg.get('image_keys', '[]') + if isinstance(meta, (str, bytes, bytearray)): + seg['meta'] = json.loads(meta) + if isinstance(global_meta, (str, bytes, bytearray)): + seg['global_meta'] = json.loads(global_meta) + if isinstance(image_keys, (str, bytes, bytearray)): + seg['image_keys'] = json.loads(image_keys) return seg def _construct_criteria(self, criteria: Optional[dict] = None) -> dict: # noqa: C901 From b9c835db85a154ae770f3e6f05f9b301c35aad00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Thu, 2 Apr 2026 10:22:50 +0800 Subject: [PATCH 37/46] fix doc service mock test lint --- .../basic_tests/RAG/test_doc_service_mock.py | 41 +++++++++++++------ 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/tests/basic_tests/RAG/test_doc_service_mock.py b/tests/basic_tests/RAG/test_doc_service_mock.py index d585e3d75..f134c0649 100644 --- a/tests/basic_tests/RAG/test_doc_service_mock.py +++ b/tests/basic_tests/RAG/test_doc_service_mock.py @@ -713,18 +713,35 @@ def setup_class(cls): def _queue_task(task_id: str, final_status: DocStatus): cls._pending_task_status[task_id] = final_status - cls.manager._parser_client.add_doc = lambda task_id, algo_id, kb_id, doc_id, file_path, metadata=None, reparse_group=None: ( - _queue_task(task_id, DocStatus.SUCCESS) or - BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) - ) - cls.manager._parser_client.update_meta = lambda task_id, algo_id, kb_id, doc_id, metadata=None, file_path=None: ( - _queue_task(task_id, DocStatus.SUCCESS) or - BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) - ) - cls.manager._parser_client.delete_doc = lambda task_id, algo_id, kb_id, doc_id: ( - _queue_task(task_id, DocStatus.SUCCESS) or - BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) - ) + def _add_doc(task_id, algo_id, kb_id, doc_id, file_path, metadata=None, reparse_group=None): + _queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse( + code=200, + msg='success', + data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}, + ) + + cls.manager._parser_client.add_doc = _add_doc + + def _update_meta(task_id, algo_id, kb_id, doc_id, metadata=None, file_path=None): + _queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse( + code=200, + msg='success', + data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}, + ) + + cls.manager._parser_client.update_meta = _update_meta + + def _delete_doc(task_id, algo_id, kb_id, doc_id): + _queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse( + code=200, + msg='success', + data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}, + ) + + cls.manager._parser_client.delete_doc = _delete_doc cls.manager._parser_client.cancel_task = lambda 
task_id: BaseResponse( code=200, msg='success', data={'task_id': task_id, 'cancel_status': True} ) From 5f7c27fee43d9dccf38ca721f1eb7f533df159eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 7 Apr 2026 16:12:16 +0800 Subject: [PATCH 38/46] Converge document manager parameters --- docs/lazyllm-skill/assets/rag/document.md | 5 +- examples/rag/doc_service_mock_example.py | 2 - .../rag_with_parsing_service/README.CN.md | 2 + examples/rag_with_parsing_service/README.md | 2 + lazyllm/docs/tools/tool_rag.py | 24 +- lazyllm/tools/rag/__init__.py | 2 + lazyllm/tools/rag/doc_impl.py | 2 +- lazyllm/tools/rag/doc_service/doc_manager.py | 89 ++++-- lazyllm/tools/rag/doc_service/doc_server.py | 20 ++ lazyllm/tools/rag/document.py | 106 +++++-- .../RAG/test_doc_service_doc_server.py | 14 + .../basic_tests/RAG/test_doc_service_mock.py | 19 -- tests/basic_tests/RAG/test_document.py | 269 +++++++++++++++++- 13 files changed, 469 insertions(+), 87 deletions(-) diff --git a/docs/lazyllm-skill/assets/rag/document.md b/docs/lazyllm-skill/assets/rag/document.md index 26e1f1eaa..7cee6a46b 100644 --- a/docs/lazyllm-skill/assets/rag/document.md +++ b/docs/lazyllm-skill/assets/rag/document.md @@ -91,8 +91,8 @@ def processYml(file): ... print("Call the function processYml.") ... return [DocNode(text=data)] ... -doc1 = Document(dataset_path="your_files_path", create_ui=False) -doc2 = Document(dataset_path="your_files_path", create_ui=False) +doc1 = Document(dataset_path="your_files_path") +doc2 = Document(dataset_path="your_files_path") doc1.add_reader("**/*.yml", YmlReader) print(doc1._impl._local_file_reader) {'**/*.yml': } @@ -213,4 +213,3 @@ res = retriever(query=query) print(f"answer: {res}") ``` - diff --git a/examples/rag/doc_service_mock_example.py b/examples/rag/doc_service_mock_example.py index 0fdfd1e53..0329cc588 100644 --- a/examples/rag/doc_service_mock_example.py +++ b/examples/rag/doc_service_mock_example.py @@ -32,7 +32,6 @@ def _wait_task(base_url: str, task_id: str, targets: set[str], timeout: float = def main(): parser = argparse.ArgumentParser(description='DocService mock quickstart.') parser.add_argument('--wait', action='store_true', help='Keep server alive for manual API/docs inspection.') - parser.add_argument('--doc-server-port', type=int, default=None, help='DocServer listen port.') args = parser.parse_args() with tempfile.TemporaryDirectory(prefix='lazyllm_doc_service_demo_') as tmp: @@ -46,7 +45,6 @@ def main(): dataset_path=storage, manager=True, name='demo_doc_service', - doc_server_port=args.doc_server_port, ) doc.start() diff --git a/examples/rag_with_parsing_service/README.CN.md b/examples/rag_with_parsing_service/README.CN.md index 787e40d74..20d92fe44 100644 --- a/examples/rag_with_parsing_service/README.CN.md +++ b/examples/rag_with_parsing_service/README.CN.md @@ -23,6 +23,8 @@ RAG 解析服务示例 参考 `document.py` 的配置: - `manager=DocumentProcessor(url="http://0.0.0.0:9966")` 指向解析服务。 +- 该模式下 `Document` 不会创建 `DocServer` 或 UI,`dataset_path` 也不会再用于本地扫盘。 +- 必须显式提供 `store_conf`,且不能使用纯 map store。 - `server=9977` 将 Document 作为服务暴露。 然后在 `retriever_using_url.py` 中使用: diff --git a/examples/rag_with_parsing_service/README.md b/examples/rag_with_parsing_service/README.md index bfd7c040e..441844351 100644 --- a/examples/rag_with_parsing_service/README.md +++ b/examples/rag_with_parsing_service/README.md @@ -26,6 +26,8 @@ others can access it remotely by URL, and Retrievers can use that URL directly. 
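For orientation, a minimal sketch of the setup that `document.py` in this example describes might look like the following. It is not the example's literal code: the dataset path and store configuration are stand-ins (the exact `store_conf` schema depends on the store backend you deploy), and the import locations are assumed from this repository's layout; only the parsing-service URL `http://0.0.0.0:9966` and the `server=9977` port are taken from the text above.

```python
from lazyllm.tools.rag import Document, DocumentProcessor

# store_conf is required in this mode and must not be a pure map store; the
# value below is a placeholder -- substitute the configuration of whichever
# store backend your deployment actually uses.
store_conf = {'type': 'milvus', 'kwargs': {'uri': 'path/to/milvus.db'}}

doc = Document(
    dataset_path='path/to/files',  # not used for local path monitoring in this mode
    manager=DocumentProcessor(url='http://0.0.0.0:9966'),  # the parsing service above
    store_conf=store_conf,
    server=9977,  # expose the Document itself as a service
)
doc.start()
```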
See `document.py` for the setup: - `manager=DocumentProcessor(url="http://0.0.0.0:9966")` points to the parsing service. +- In this mode `Document` does not create a `DocServer` or UI, and `dataset_path` is not used for local path monitoring. +- `store_conf` is required and must not be a pure map store. - `server=9977` exposes the Document as a service. Then use `retriever_using_url.py` to create: diff --git a/lazyllm/docs/tools/tool_rag.py b/lazyllm/docs/tools/tool_rag.py index 4e812f294..1a0e7698c 100644 --- a/lazyllm/docs/tools/tool_rag.py +++ b/lazyllm/docs/tools/tool_rag.py @@ -26,8 +26,8 @@ Args: dataset_path (Optional[str]): Path to the dataset directory. If not found, the system will attempt to locate it in ``lazyllm.config["data_path"]``. embed (Optional[Union[Callable, Dict[str, Callable]]]): Embedding function or mapping of embedding functions. When a dictionary is provided, keys are embedding names and values are embedding models. - create_ui (bool, optional): Deprecated alias of ``manager`` kept for compatibility. - manager (Union[bool, str], optional): Whether to enable the document manager. If ``True``, launches ``DocServer`` together with a local parsing service. If ``'ui'``, also enables the document management web UI. + create_ui (bool, optional): Whether to create the document-management UI. It requires an available ``DocServer`` and can be combined with ``manager=True`` or ``manager=DocServer(...)``. + manager (Union[bool, str, DocServer, Document._Manager, DocumentProcessor], optional): Document manager mode. ``True`` launches a local ``DocServer`` together with a local parsing service. ``DocServer(...)`` connects an existing document-management service. ``DocumentProcessor(...)`` connects a parsing service only and requires a non-map ``store_conf``. ``'ui'`` is accepted as a deprecated compatibility alias for ``manager=True, create_ui=True``. server (Union[bool, int], optional): Whether to run a server interface for knowledge bases. ``True`` enables a default server, an integer specifies a custom port, and ``False`` disables it. Defaults to ``False``. name (Optional[str]): Name identifier for this document collection. Defaults to the system default name. launcher (Optional[Launcher]): Launcher instance for managing server processes. Defaults to a remote asynchronous launcher. @@ -37,8 +37,7 @@ display_name (Optional[str]): Human-readable display name for this document module. Defaults to the collection name. description (Optional[str]): Description of the document collection. Defaults to ``"algorithm description"``. schema_extractor (Optional[Union[LLMBase, SchemaExtractor]]): Optional schema extractor used for metadata schema analysis and registration. - doc_server_port (Optional[int]): Explicit local port for ``DocServer`` when ``manager`` is enabled. - enable_path_monitoring (Optional[bool]): Whether to watch the local dataset path for file additions and removals. Defaults to enabled for local documents without manager mode. + enable_path_monitoring (Optional[bool]): Whether to watch the local dataset path for file additions and removals. Defaults to enabled only for local documents without ``DocServer``/``DocumentProcessor`` manager mode. 
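As a quick illustration of the converged ``manager`` / ``create_ui`` combinations documented above, a hedged sketch follows; the dataset paths and the service URL are placeholders rather than defaults, and the ``DocServer`` import path mirrors the one used in this patch's tests.

```python
from lazyllm.tools.rag import Document
from lazyllm.tools.rag.doc_service import DocServer

# Local management: launch DocServer plus a local parsing service, with the
# optional management web UI on top of it.
doc_local = Document('path/to/dataset', manager=True, create_ui=True)

# Deprecated but still accepted spelling of the same combination.
doc_legacy = Document('path/to/dataset', manager='ui')

# Remote management: attach to an already running document-management service.
doc_remote = Document('path/to/dataset', manager=DocServer(url='http://127.0.0.1:20000'))
```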
''') add_chinese_doc('Document', '''\ @@ -49,8 +48,8 @@ Args: dataset_path (Optional[str]): 数据集目录路径。如果路径不存在,系统会尝试在 ``lazyllm.config["data_path"]`` 中查找。 embed (Optional[Union[Callable, Dict[str, Callable]]]): 文档向量化函数或函数字典。若为字典,键为 embedding 名称,值为对应的模型。 - create_ui (bool, optional): ``manager`` 的兼容别名,已废弃。 - manager (Union[bool, str], optional): 是否启用文档管理服务。``True`` 表示启动 ``DocServer`` 及其本地 parsing service;``'ui'`` 表示同时启动 Web 管理界面。 + create_ui (bool, optional): 是否创建文档管理 UI。该能力要求当前存在可用的 ``DocServer``,可与 ``manager=True`` 或 ``manager=DocServer(...)`` 组合使用。 + manager (Union[bool, str, DocServer, Document._Manager, DocumentProcessor], optional): 文档管理模式。``True`` 表示启动本地 ``DocServer`` 及其 parsing service;``DocServer(...)`` 表示连接已有文档管理服务;``DocumentProcessor(...)`` 表示仅连接解析服务,此时必须提供非 map 的 ``store_conf``;``'ui'`` 仅作为 ``manager=True, create_ui=True`` 的兼容写法保留。 server (Union[bool, int], optional): 是否为知识库运行服务接口。``True`` 表示启动默认服务;整型数值表示自定义端口;``False`` 表示关闭。默认为 ``False``。 name (Optional[str]): 文档集合的名称标识符。默认为系统默认名称。 launcher (Optional[Launcher]): 启动器实例,用于管理服务进程。默认使用远程异步启动器。 @@ -60,8 +59,7 @@ display_name (Optional[str]): 文档模块的可读显示名称。默认为集合名称。 description (Optional[str]): 文档集合的描述。默认为 ``"algorithm description"``。 schema_extractor (Optional[Union[LLMBase, SchemaExtractor]]): 可选 schema extractor,用于元数据 schema 分析与注册。 - doc_server_port (Optional[int]): ``manager`` 启用时 ``DocServer`` 使用的本地端口。 - enable_path_monitoring (Optional[bool]): 是否监控本地数据目录的文件新增和删除。对非 manager 的本地文档默认开启。 + enable_path_monitoring (Optional[bool]): 是否监控本地数据目录的文件新增和删除。仅在未接入 ``DocServer`` / ``DocumentProcessor`` 的本地模式下默认开启。 ''') add_doc_service_english_doc('DocServer', '''\ @@ -657,8 +655,8 @@ ... data = f.read() ... return [DocNode(text=data)] ... ->>> doc1 = Document(dataset_path="your_files_path", create_ui=False) ->>> doc2 = Document(dataset_path="your_files_path", create_ui=False) +>>> doc1 = Document(dataset_path="your_files_path") +>>> doc2 = Document(dataset_path="your_files_path") >>> files = ["your_yml_files"] >>> docs1 = doc1._impl._reader.load_data(input_files=files) >>> docs2 = doc2._impl._reader.load_data(input_files=files) @@ -702,8 +700,8 @@ ... print("Call the function processYml.") ... return [DocNode(text=data)] ... ->>> doc1 = Document(dataset_path="your_files_path", create_ui=False) ->>> doc2 = Document(dataset_path="your_files_path", create_ui=False) +>>> doc1 = Document(dataset_path="your_files_path") +>>> doc2 = Document(dataset_path="your_files_path") >>> doc1.add_reader("**/*.yml", YmlReader) >>> print(doc1._impl._local_file_reader) {'**/*.yml': } @@ -1142,7 +1140,7 @@ ... return [DocNode(text=data)] ... >>> files = ["your_yml_files"] ->>> doc = Document(dataset_path="your_files_path", create_ui=False) +>>> doc = Document(dataset_path="your_files_path") >>> reader = doc._impl._reader.load_data(input_files=files) # Call the class YmlReader. 
''') diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index f80c57eee..a99670771 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -4,6 +4,7 @@ # flake8: noqa: E402 from .document import Document +from .doc_manager import DocManager from .graph_document import GraphDocument, UrlGraphDocument from .retriever import Retriever, TempDocRetriever, ContextRetriever, WeightedRetriever, PriorityRetriever from .graph_retriever import GraphRetriever @@ -29,6 +30,7 @@ __all__ = [ 'add_post_action_for_default_reader', 'Document', + 'DocManager', 'GraphDocument', 'UrlGraphDocument', 'Reranker', diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 9b3b05ee7..c4e0dffa5 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -185,7 +185,7 @@ def _lazy_init(self) -> None: self._resolve_index_pending_registrations() if self._processor: - assert cloud and isinstance(self._processor, DocumentProcessor) + assert isinstance(self._processor, DocumentProcessor) self._processor.register_algorithm(self._algo_name, self._store, self._reader, self.node_groups, self._schema_extractor, self._display_name, self._description) else: diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index 913abea52..a088b5698 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -84,6 +84,26 @@ def _sha256_file(file_path: str) -> str: return digest.hexdigest() +def _merge_transfer_metadata( + source_metadata: Dict[str, Any], target_metadata: Optional[Dict[str, Any]] +) -> Dict[str, Any]: + metadata = dict(source_metadata or {}) + if target_metadata: + metadata.update(target_metadata) + return metadata + + +def _resolve_transfer_target_path( + source_path: str, target_filename: Optional[str], target_file_path: Optional[str] +) -> str: + if target_file_path: + return target_file_path + if target_filename: + base_dir = os.path.dirname(source_path) if source_path else '' + return os.path.join(base_dir, target_filename) if base_dir else target_filename + return source_path + + class _ParserClient: def __init__(self, parser_url: str): parser_url = parser_url.rstrip('/') @@ -602,17 +622,19 @@ def _upsert_doc( metadata: Dict[str, Any], source_type: SourceType, upload_status: DocStatus = DocStatus.SUCCESS, + allowed_path_doc_ids: Optional[Set[str]] = None, ): now = now_ts() file_type = os.path.splitext(path)[1].lstrip('.').lower() or None size_bytes = os.path.getsize(path) if os.path.exists(path) else None content_hash = _sha256_file(path) if os.path.exists(path) else None + allowed_path_doc_ids = allowed_path_doc_ids or set() try: with self._db_manager.get_session() as session: Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) row = session.query(Doc).filter(Doc.doc_id == doc_id).first() path_row = session.query(Doc).filter(Doc.path == path).first() - if path_row is not None and path_row.doc_id != doc_id: + if path_row is not None and path_row.doc_id != doc_id and path_row.doc_id not in allowed_path_doc_ids: raise DocServiceError( 'E_STATE_CONFLICT', f'doc path already exists: {path}', @@ -645,7 +667,11 @@ def _upsert_doc( session.add(row) except IntegrityError as exc: existing = self._get_doc_by_path(path) - if existing is not None and existing.get('doc_id') != doc_id: + if ( + existing is not None + and existing.get('doc_id') != doc_id + and existing.get('doc_id') not in allowed_path_doc_ids + ): raise 
DocServiceError( 'E_STATE_CONFLICT', f'doc path already exists: {path}', @@ -938,7 +964,7 @@ def _enqueue_task( file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, cleanup_policy: Optional[str] = None, parser_kb_id: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None, - extra_message: Optional[Dict[str, Any]] = None, + extra_message: Optional[Dict[str, Any]] = None, parser_doc_id: Optional[str] = None, ): task_id = str(uuid4()) task_message = { @@ -976,7 +1002,7 @@ def _enqueue_task( ) try: self._create_parser_task( - task_id, doc_id, kb_id, algo_id, task_type, + task_id, parser_doc_id or doc_id, kb_id, algo_id, task_type, file_path=file_path, metadata=metadata, reparse_group=reparse_group, parser_kb_id=parser_kb_id, transfer_params=transfer_params, ) @@ -1122,13 +1148,13 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An {'doc_id': item.doc_id, 'source_kb_id': item.source_kb_id, 'target_kb_id': item.target_kb_id}, ) seen_pairs.add(item_key) - target_key = (item.doc_id, item.target_kb_id, item.target_algo_id) + target_key = (item.target_doc_id, item.target_kb_id, item.target_algo_id) if target_key in seen_targets: raise DocServiceError( 'E_INVALID_PARAM', 'duplicate transfer target detected', { - 'doc_id': item.doc_id, + 'doc_id': item.target_doc_id, 'target_kb_id': item.target_kb_id, 'target_algo_id': item.target_algo_id, }, @@ -1150,11 +1176,11 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An self._validate_kb_algorithm(item.source_kb_id, item.source_algo_id) self._validate_kb_algorithm(item.target_kb_id, item.target_algo_id) self._assert_action_allowed(item.doc_id, item.source_kb_id, item.source_algo_id, 'transfer') - if self._has_kb_document(item.target_kb_id, item.doc_id): + if self._has_kb_document(item.target_kb_id, item.target_doc_id): raise DocServiceError( 'E_STATE_CONFLICT', - f'doc already exists in target kb: {item.doc_id}', - {'doc_id': item.doc_id, 'target_kb_id': item.target_kb_id}, + f'doc already exists in target kb: {item.target_doc_id}', + {'doc_id': item.target_doc_id, 'target_kb_id': item.target_kb_id}, ) source_snapshot = self._get_parse_snapshot(item.doc_id, item.source_kb_id, item.source_algo_id) if source_snapshot is None or source_snapshot.get('status') != DocStatus.SUCCESS.value: @@ -1168,19 +1194,27 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An 'status': source_snapshot.get('status') if source_snapshot else None, }, ) + source_metadata = _from_json(doc.get('meta')) + target_metadata = _merge_transfer_metadata(source_metadata, item.target_metadata) + target_path = _resolve_transfer_target_path(doc.get('path'), item.target_filename, item.target_file_path) + target_filename = os.path.basename(target_path) prepared_items.append({ 'doc_id': item.doc_id, + 'target_doc_id': item.target_doc_id, 'source_kb_id': item.source_kb_id, 'source_algo_id': item.source_algo_id, 'target_kb_id': item.target_kb_id, 'target_algo_id': item.target_algo_id, 'mode': item.mode, 'file_path': doc.get('path'), - 'metadata': _from_json(doc.get('meta')), + 'target_file_path': target_path, + 'target_filename': target_filename, + 'metadata': target_metadata, + 'source_type': SourceType(doc.get('source_type')), 'transfer_params': { 'mode': 'mv' if item.mode == 'move' else 'cp', 'target_algo_id': item.target_algo_id, - 'target_doc_id': item.doc_id, + 'target_doc_id': item.target_doc_id, 'target_kb_id': 
item.target_kb_id, }, }) @@ -1327,41 +1361,59 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: for item in prepared_items: task_id = None try: - self._ensure_kb_document(item['target_kb_id'], item['doc_id']) + self._upsert_doc( + doc_id=item['target_doc_id'], + filename=item['target_filename'], + path=item['target_file_path'], + metadata=item['metadata'], + source_type=item['source_type'], + upload_status=DocStatus.SUCCESS, + allowed_path_doc_ids={item['doc_id']}, + ) + self._ensure_kb_document(item['target_kb_id'], item['target_doc_id']) task_id, snapshot = self._enqueue_task( - item['doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, + item['target_doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, idempotency_key=request.idempotency_key, file_path=item['file_path'], metadata=item['metadata'], parser_kb_id=item['source_kb_id'], transfer_params=item['transfer_params'], extra_message={ + 'source_doc_id': item['doc_id'], 'source_kb_id': item['source_kb_id'], 'source_algo_id': item['source_algo_id'], 'target_kb_id': item['target_kb_id'], 'target_algo_id': item['target_algo_id'], + 'target_doc_id': item['target_doc_id'], 'mode': item['mode'], }, + parser_doc_id=item['doc_id'], ) error_code = None error_msg = None accepted = True except Exception as exc: - snapshot = self._get_parse_snapshot(item['doc_id'], item['target_kb_id'], item['target_algo_id']) or {} + snapshot = self._get_parse_snapshot( + item['target_doc_id'], item['target_kb_id'], item['target_algo_id'] + ) or {} task_id = task_id or snapshot.get('current_task_id') error_code = snapshot.get('last_error_code') if not error_code: error_code = exc.biz_code if isinstance(exc, DocServiceError) else type(exc).__name__ - error_msg = snapshot.get('last_error_msg') or (exc.msg if isinstance(exc, DocServiceError) else str(exc)) + error_msg = snapshot.get('last_error_msg') or ( + exc.msg if isinstance(exc, DocServiceError) else str(exc) + ) accepted = False items.append({ 'doc_id': item['doc_id'], + 'target_doc_id': item['target_doc_id'], 'task_id': task_id, 'source_kb_id': item['source_kb_id'], 'target_kb_id': item['target_kb_id'], 'source_algo_id': item['source_algo_id'], 'target_algo_id': item['target_algo_id'], 'mode': item['mode'], + 'target_file_path': item['target_file_path'], 'status': snapshot.get('status', DocStatus.FAILED.value), 'accepted': accepted, 'error_code': error_code, @@ -1504,10 +1556,11 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 and task_message.get('mode') == 'move' ): source_kb_id = task_message.get('source_kb_id') + source_doc_id = task_message.get('source_doc_id') or doc_id if source_kb_id and source_kb_id != kb_id: - self._remove_kb_document(source_kb_id, doc_id) - self._delete_parse_snapshots(doc_id, source_kb_id) - self._sync_doc_upload_status(doc_id) + self._remove_kb_document(source_kb_id, source_doc_id) + self._delete_parse_snapshots(source_doc_id, source_kb_id) + self._sync_doc_upload_status(source_doc_id) self._update_task_record( callback.task_id, diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py index e6fac297e..40098851b 100644 --- a/lazyllm/tools/rag/doc_service/doc_server.py +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -6,6 +6,8 @@ import traceback from typing import Any, Dict, List, Optional +import requests + from lazyllm import LOG, FastapiApp as app, ModuleBase, ServerModule, UrlModule, once_wrapper from lazyllm.thirdparty import 
fastapi @@ -586,6 +588,11 @@ def health(self): self._lazy_init() return BaseResponse(code=200, msg='success', data=self._manager.health()) + @app.get('/v1/internal/parser-url') + def get_parser_url(self): + self._lazy_init() + return BaseResponse(code=200, msg='success', data={'parser_url': self._parser_url}) + def __call__(self, func_name: str, *args, **kwargs): return getattr(self, func_name)(*args, **kwargs) @@ -691,6 +698,19 @@ def url(self): def _url(self): return self.url + @property + def parser_url(self): + if self._raw_impl: + return self._raw_impl._parser_url + base_url = self.url.rsplit('/', 1)[0] + try: + response = requests.get(f'{base_url}/v1/internal/parser-url', timeout=5) + response.raise_for_status() + return response.json()['data']['parser_url'] + except (requests.RequestException, KeyError, TypeError, ValueError) as exc: + LOG.warning(f'[DocServer] failed to resolve remote parser_url from {base_url}: {exc}') + return None + @staticmethod def _normalize_dispatch_result(result): if isinstance(result, fastapi.responses.JSONResponse): diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 411e646d2..f97e0a6d9 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -1,4 +1,5 @@ import os +import warnings from typing import Callable, Optional, Dict, Union, List, Type, Set, Tuple from functools import cached_property from pydantic import BaseModel @@ -26,6 +27,10 @@ _LOCAL_PYTHONPATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) +def _is_local_map_store(store_conf: Optional[Dict]) -> bool: + return isinstance(store_conf, dict) and store_conf.get('type') == 'map' + + class CallableDict(dict): def __call__(self, cls, *args, **kw): return self[cls](*args, **kw) @@ -42,13 +47,14 @@ class _Manager(ModuleBase): DEFAULT_GROUP_NAME = '__default__' def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, Dict[str, Callable]]] = None, - manager: Union[bool, str] = False, server: Union[bool, int] = False, name: Optional[str] = None, + manager: Union[bool, str, DocServer] = False, server: Union[bool, int] = False, + name: Optional[str] = None, launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None, doc_fields: Optional[Dict[str, DocField]] = None, cloud: bool = False, doc_files: Optional[List[str]] = None, processor: Optional[DocumentProcessor] = None, display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, - doc_server_port: Optional[int] = None, + create_ui: bool = False, enable_path_monitoring: Optional[bool] = None): super().__init__() self._origin_path, self._doc_files, self._cloud = dataset_path, doc_files, cloud @@ -64,6 +70,32 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, self._dataset_path = dataset_path self._embed = self._get_embeds(embed) self._processor = processor + compat_ui_manager = manager == 'ui' + if compat_ui_manager: + lazyllm.LOG.warning('`manager=\'ui\'` is deprecated, use `manager=True, create_ui=True` instead') + elif isinstance(manager, str): + raise ValueError(f'Unsupported manager value: {manager}') + spawn_doc_server = bool(manager) and not isinstance(manager, DocServer) + connect_doc_server = isinstance(manager, DocServer) + doc_impl_dataset_path = dataset_path if not (spawn_doc_server or connect_doc_server) else None + self._doc_impl_dataset_path = doc_impl_dataset_path + self._doc_processor = None + if 
spawn_doc_server: + self._doc_processor = DocumentProcessor(launcher=self._launcher, pythonpath=_LOCAL_PYTHONPATH) + self._doc_processor.start() + self._manager = DocServer( + launcher=self._launcher, + storage_dir=dataset_path, + parser_url=self._doc_processor.url, + pythonpath=_LOCAL_PYTHONPATH, + ) + elif connect_doc_server: + self._manager = manager + parser_url = getattr(getattr(manager, '_raw_impl', None), '_parser_url', None) + if parser_url is None: + parser_url = manager.parser_url + if parser_url: + self._doc_processor = DocumentProcessor(url=parser_url) self._schema_extractor = self._register_submodules(schema_extractor) self._store_conf = store_conf self._display_name = display_name @@ -71,28 +103,20 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, name = name or self.DEFAULT_GROUP_NAME if not display_name: display_name = name if enable_path_monitoring is None: - enable_path_monitoring = False if manager else True + enable_path_monitoring = False if (spawn_doc_server or connect_doc_server or processor) else True self._enable_path_monitoring = enable_path_monitoring + doc_processor = self._doc_processor or processor self._kbs = CallableDict({name: DocImpl( embed=self._embed, - dataset_path=dataset_path, + dataset_path=self._doc_impl_dataset_path, enable_path_monitoring=enable_path_monitoring, doc_files=doc_files, global_metadata_desc=doc_fields, - store=store_conf, processor=processor, algo_name=name, display_name=display_name, + store=store_conf, processor=doc_processor, algo_name=name, display_name=display_name, description=description, schema_extractor=schema_extractor)}) - if manager: - self._doc_processor = DocumentProcessor(launcher=self._launcher, pythonpath=_LOCAL_PYTHONPATH) - self._doc_processor.start() - self._manager = DocServer( - launcher=self._launcher, - storage_dir=dataset_path, - port=doc_server_port, - parser_url=self._doc_processor.url, - pythonpath=_LOCAL_PYTHONPATH, - ) - if manager == 'ui': self._docweb = DocWebModule(doc_server=self._manager) + if create_ui: + self.ensure_doc_web() if server: self._kbs = ServerModule(self._kbs, port=(None if isinstance(server, bool) else int(server))) self._global_metadata_desc = doc_fields @@ -111,6 +135,17 @@ def web_url(self): if hasattr(self, '_docweb'): return self._docweb.url return None + def ensure_doc_web(self): + if hasattr(self, '_docweb'): + return self._docweb + if not hasattr(self, '_manager') or not isinstance(self._manager, DocServer): + raise ValueError( + '`create_ui=True` requires an available DocServer. ' + 'Set `manager=True` or pass `manager=DocServer(...)`.' 
+ ) + self._docweb = DocWebModule(doc_server=self._manager) + return self._docweb + def _get_embeds(self, embed): embeds = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {} return self._register_submodules(embeds) @@ -128,13 +163,13 @@ def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, embed = self._get_embeds(embed) if embed else self._embed schema_extractor = self._register_submodules(schema_extractor) or self._schema_extractor impl = DocImpl( - dataset_path=self._dataset_path, + dataset_path=self._doc_impl_dataset_path, embed=embed, kb_group_name=name, enable_path_monitoring=self._enable_path_monitoring, global_metadata_desc=doc_fields, store=store_conf or self._store_conf, - processor=self._processor, + processor=self._doc_processor or self._processor, algo_name=name, display_name=name, description=self._description, @@ -166,17 +201,27 @@ def __new__(cls, *args, **kw): return super().__new__(cls) def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, - create_ui: bool = False, manager: Union[bool, str, 'Document._Manager', DocumentProcessor] = False, - server: Union[bool, int] = False, name: Optional[str] = None, launcher: Optional[Launcher] = None, - doc_files: Optional[List[str]] = None, doc_fields: Dict[str, DocField] = None, + create_ui: bool = False, + manager: Union[bool, str, DocServer, 'Document._Manager', DocumentProcessor] = False, + server: Union[bool, int] = False, name: Optional[str] = None, + launcher: Optional[Launcher] = None, doc_files: Optional[List[str]] = None, + doc_fields: Dict[str, DocField] = None, store_conf: Optional[Dict] = None, display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, - doc_server_port: Optional[int] = None, enable_path_monitoring: Optional[bool] = None): + enable_path_monitoring: Optional[bool] = None): super().__init__() - if create_ui: - lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') - manager = create_ui + if isinstance(manager, str): + if manager != 'ui': + raise ValueError(f'Unsupported manager value: {manager}') + warnings.warn( + '`manager="ui"` is deprecated, use `manager=True, create_ui=True` instead.', + DeprecationWarning, + stacklevel=2, + ) + lazyllm.LOG.warning('`manager="ui"` is deprecated, use `manager=True, create_ui=True` instead') + create_ui = True + manager = True if isinstance(dataset_path, (tuple, list)): doc_fields = dataset_path dataset_path = None @@ -198,23 +243,26 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal f'while received `{dataset_path}`') manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed, schema_extractor=schema_extractor) + if create_ui: + manager.ensure_doc_web() self._manager = manager self._curr_group = name else: if isinstance(manager, DocumentProcessor): + if store_conf is None: + raise ValueError('`store_conf` is required when `manager` is a DocumentProcessor') + if _is_local_map_store(store_conf): + raise ValueError('`manager=DocumentProcessor(...)` does not support pure local map store') processor, cloud = manager, True processor.start() manager = False - assert name, '`Name` of Document is necessary when using cloud service' - assert store_conf.get('type') != 'map', 'Cloud manager is not supported when using map store' - assert not dataset_path, 'Cloud manager is not 
supported with local dataset path' else: cloud, processor = False, None self._manager = Document._Manager(dataset_path, embed, manager, server, name, launcher, store_conf, doc_fields, cloud=cloud, doc_files=doc_files, processor=processor, display_name=display_name, description=description, schema_extractor=schema_extractor, - doc_server_port=doc_server_port, + create_ui=create_ui, enable_path_monitoring=enable_path_monitoring) self._curr_group = name self._doc_to_db_processor: DocToDbProcessor = None diff --git a/tests/basic_tests/RAG/test_doc_service_doc_server.py b/tests/basic_tests/RAG/test_doc_service_doc_server.py index 5ecff01a4..5297bb463 100644 --- a/tests/basic_tests/RAG/test_doc_service_doc_server.py +++ b/tests/basic_tests/RAG/test_doc_service_doc_server.py @@ -4,6 +4,7 @@ import tempfile import pytest +import requests from lazyllm.thirdparty import fastapi @@ -153,6 +154,19 @@ def test_run_wraps_doc_service_error(): assert body['data']['x'] == 1 +def test_parser_url_returns_none_when_remote_endpoint_unavailable(monkeypatch): + server = object.__new__(DocServer) + server._raw_impl = None + server._impl = type('FakeImpl', (), {'_url': 'http://127.0.0.1:19002/generate'})() + + def _raise(*args, **kwargs): + raise requests.ConnectionError('missing endpoint') + + monkeypatch.setattr(requests, 'get', _raise) + + assert server.parser_url is None + + def test_cancel_task_http_requires_task_id(server_impl): with pytest.raises(fastapi.HTTPException) as exc: asyncio.run(server_impl.cancel_task(_JsonRequest({}))) diff --git a/tests/basic_tests/RAG/test_doc_service_mock.py b/tests/basic_tests/RAG/test_doc_service_mock.py index f134c0649..7a3e67b34 100644 --- a/tests/basic_tests/RAG/test_doc_service_mock.py +++ b/tests/basic_tests/RAG/test_doc_service_mock.py @@ -238,25 +238,6 @@ def test_p0_endpoints_and_core_flows(self): kb_delete = requests.delete(f'{self.base_url}/v1/kbs/kb_a', timeout=8) assert kb_delete.status_code == 200 - def test_document_manager_supports_doc_server_port(self): - from lazyllm import Document - - tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_port_') - storage_dir = os.path.join(tmp_dir, 'uploads') - os.makedirs(storage_dir, exist_ok=True) - fixed_port = 18898 - doc = Document(dataset_path=storage_dir, manager=True, doc_server_port=fixed_port, name='doc_port_test') - try: - self._ensure_bindable() - doc.start() - base_url = doc.manager.url.rsplit('/', 1)[0] - assert base_url.endswith(f':{fixed_port}') - health = requests.get(f'{base_url}/v1/health', timeout=5) - assert health.status_code == 200 - finally: - doc.stop() - shutil.rmtree(tmp_dir, ignore_errors=True) - def test_missing_p0_endpoints_exist(self): kb_create = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_endpoints'}, timeout=5) assert kb_create.status_code == 200 diff --git a/tests/basic_tests/RAG/test_document.py b/tests/basic_tests/RAG/test_document.py index be598f006..1d279cdd1 100644 --- a/tests/basic_tests/RAG/test_document.py +++ b/tests/basic_tests/RAG/test_document.py @@ -220,6 +220,176 @@ def test_create_document(self): Document(dataset_path) Document(dataset_path + os.sep) + def test_dataset_path_enables_monitoring_by_default_without_manager(self): + doc = Document(self._build_dataset()) + assert doc._manager._enable_path_monitoring is True + assert doc._impl._dataset_path == doc._manager._origin_path + + def test_manager_true_disables_monitoring_and_creates_ui(self): + dataset_path = self._build_dataset() + calls = {} + + class FakeDocumentProcessor: + def __init__(self, *args, 
**kwargs): + self.url = 'http://127.0.0.1:19001/generate' + + def start(self): + calls['processor_started'] = True + + class FakeDocServer: + def __init__(self, *args, **kwargs): + calls['storage_dir'] = kwargs.get('storage_dir') + self._url = 'http://127.0.0.1:19002/generate' + + class FakeDocWebModule: + def __init__(self, doc_server, *args, **kwargs): + calls['web_doc_server'] = doc_server + self.url = 'http://127.0.0.1:19003' + + def stop(self): + return None + + with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ + patch('lazyllm.tools.rag.document.DocServer', FakeDocServer), \ + patch('lazyllm.tools.rag.document.DocWebModule', FakeDocWebModule): + doc = Document(dataset_path, manager=True, create_ui=True) + try: + assert calls['processor_started'] is True + assert calls['storage_dir'] == doc._manager._origin_path + assert calls['web_doc_server'] is doc._manager._manager + assert doc._manager._enable_path_monitoring is False + assert doc._impl._dataset_path == doc._manager._dataset_path + finally: + doc.stop() + + def test_doc_server_manager_disables_monitoring_without_local_path_following(self): + dataset_path = self._build_dataset() + calls = {} + + class FakeDocumentProcessor: + def __init__(self, *args, **kwargs): + calls['parser_url'] = kwargs.get('url') + + def start(self): + return None + + class FakeDocWebModule: + def __init__(self, doc_server, *args, **kwargs): + calls['web_doc_server'] = doc_server + self.url = 'http://127.0.0.1:19003' + + def stop(self): + return None + + class FakeDocServer: + def __init__(self): + self.parser_url = 'http://127.0.0.1:19011/generate' + self._url = 'http://127.0.0.1:19012/generate' + + external_doc_server = FakeDocServer() + + with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ + patch('lazyllm.tools.rag.document.DocServer', FakeDocServer), \ + patch('lazyllm.tools.rag.document.DocWebModule', FakeDocWebModule): + doc = Document(dataset_path, manager=external_doc_server, create_ui=True) + try: + assert calls['parser_url'] == external_doc_server.parser_url + assert calls['web_doc_server'] is external_doc_server + assert doc._manager._enable_path_monitoring is False + assert doc._manager._dataset_path == doc._manager._origin_path + assert doc._impl._dataset_path is None + finally: + doc.stop() + + def test_document_processor_manager_requires_store_conf_and_disables_monitoring(self): + class FakeDocumentProcessor: + def __init__(self): + self.start_calls = 0 + + def start(self): + self.start_calls += 1 + + dataset_path = self._build_dataset() + processor = FakeDocumentProcessor() + + with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor): + with self.assertRaises(ValueError): + Document(dataset_path, manager=processor) + assert processor.start_calls == 0 + + with self.assertRaises(ValueError): + Document(dataset_path, manager=processor, store_conf={'type': 'map'}) + assert processor.start_calls == 0 + + doc = Document(dataset_path, manager=processor, store_conf={'type': 'milvus'}) + + assert processor.start_calls == 1 + assert doc._manager._enable_path_monitoring is False + assert doc._impl._dataset_path == doc._manager._origin_path + + def test_managed_document_keeps_local_dataset_path_for_helper_apis(self): + dataset_path = self._build_dataset() + + class FakeDocumentProcessor: + def __init__(self, *args, **kwargs): + self.url = 'http://127.0.0.1:19001/generate' + + def start(self): + return None + + class FakeDocServer: + def __init__(self, *args, 
**kwargs): + self._url = 'http://127.0.0.1:19002/generate' + + class FakeGraphRagServerModule: + def __init__(self, kg_dir): + self.kg_dir = kg_dir + + def stop(self): + return None + + expected_files = [os.path.join(dataset_path, 'rag.txt')] + + with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ + patch('lazyllm.tools.rag.document.DocServer', FakeDocServer), \ + patch('lazyllm.tools.rag.document.extract_db_schema_from_files', return_value=[]) as extract_mock, \ + patch('lazyllm.tools.rag.graph_document.GraphRagServerModule', FakeGraphRagServerModule): + doc = Document(dataset_path, manager=True) + try: + from lazyllm.tools.rag.graph_document import GraphDocument + + graph_doc = GraphDocument(doc) + assert doc._manager._dataset_path == doc._manager._origin_path + assert doc._impl._dataset_path is None + doc.extract_db_schema(MagicMock()) + extract_mock.assert_called_once_with(expected_files, unittest.mock.ANY) + assert graph_doc._kg_dir == os.path.join(dataset_path, '.graphrag_kg') + finally: + doc.stop() + + def test_remote_doc_server_manager_allows_missing_parser_url(self): + dataset_path = self._build_dataset() + + class FakeDocServer: + def __init__(self, *args, **kwargs): + self._raw_impl = None + self._url = 'http://127.0.0.1:19002/generate' + + @property + def parser_url(self): + return None + + with patch('lazyllm.tools.rag.document.DocServer', FakeDocServer), \ + patch('lazyllm.tools.rag.document.DocumentProcessor') as processor_cls: + doc_server = FakeDocServer() + doc = Document(dataset_path, manager=doc_server) + try: + processor_cls.assert_not_called() + assert doc._manager._dataset_path == doc._manager._origin_path + assert doc._impl._dataset_path is None + finally: + doc.stop() + def test_register_with_pattern(self): Document.create_node_group('AdaptiveChunk1', transform=[ TransformArgs(f=SentenceSplitter, pattern='*.txt', kwargs=dict(chunk_size=512, chunk_overlap=50)), @@ -308,7 +478,7 @@ def _test_impl(group, target): def test_doc_web_module(self): dataset_path = self._build_dataset() - doc = Document(dataset_path, manager='ui') + doc = Document(dataset_path, manager=True, create_ui=True) try: doc.create_kb_group(name='test_group') doc2 = Document(dataset_path, manager=doc.manager, name='test_group2') @@ -318,6 +488,15 @@ def test_doc_web_module(self): finally: doc.stop() + def test_manager_ui_remains_compatible(self): + dataset_path = self._build_dataset() + doc = Document(dataset_path, manager='ui') + try: + assert hasattr(doc._manager, '_docweb') + assert doc._manager._enable_path_monitoring is False + finally: + doc.stop() + def test_doc_web_module_uses_workspace_pythonpath(self): dataset_path = self._build_dataset() calls = {} @@ -338,7 +517,7 @@ def __init__(self, *args, **kwargs): with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ patch('lazyllm.tools.rag.document.DocServer', FakeDocServer): - doc = Document(dataset_path, manager='ui') + doc = Document(dataset_path, manager=True, create_ui=True) try: assert calls['processor_started'] is True assert calls['processor_pythonpath'] == document_module._LOCAL_PYTHONPATH @@ -347,6 +526,92 @@ def __init__(self, *args, **kwargs): finally: doc.stop() + def test_doc_web_module_registers_algorithms_with_spawned_processor(self): + dataset_path = self._build_dataset() + calls = {'registered_algorithms': [], 'add_doc_calls': []} + + class FakeDocumentProcessor: + def __init__(self, *args, **kwargs): + self.url = 'http://127.0.0.1:19001/generate' + + def 
start(self): + return None + + def register_algorithm(self, name, *args, **kwargs): + calls['registered_algorithms'].append(name) + + def add_doc(self, input_files, ids, metadatas=None, **kwargs): + calls['add_doc_calls'].append({ + 'input_files': input_files, + 'ids': ids, + 'metadatas': metadatas, + }) + + class FakeDocServer: + def __init__(self, *args, **kwargs): + self._url = 'http://127.0.0.1:19002/generate' + + with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ + patch('lazyllm.tools.rag.document.DocServer', FakeDocServer): + doc = Document(dataset_path, manager=True, create_ui=True) + try: + doc._impl._lazy_init() + doc2 = doc.create_kb_group(name='test_group') + doc2._impl._lazy_init() + + assert calls['registered_algorithms'] == ['__default__', 'test_group'] + assert len(calls['add_doc_calls']) == 0 + finally: + doc.stop() + + def test_create_ui_requires_doc_server(self): + with self.assertRaisesRegex(ValueError, 'requires an available DocServer'): + Document(self._build_dataset(), create_ui=True) + + def test_remote_doc_server_manager_disables_local_path_follow(self): + dataset_path = self._build_dataset() + + class FakeDocServer: + def __init__(self, *args, **kwargs): + self._raw_impl = None + self._url = 'http://127.0.0.1:19002/generate' + + @property + def parser_url(self): + return None + + with patch('lazyllm.tools.rag.document.DocServer', FakeDocServer): + doc_server = FakeDocServer() + doc = Document(dataset_path, manager=doc_server) + try: + assert doc._manager._enable_path_monitoring is False + assert doc._manager._dataset_path == doc._manager._origin_path + assert doc._impl._enable_path_monitoring is False + assert doc._impl._dataset_path is None + finally: + doc.stop() + + def test_document_processor_manager_constraints(self): + dataset_path = self._build_dataset() + processor = document_module.DocumentProcessor(url='http://127.0.0.1:9966') + + with self.assertRaisesRegex(ValueError, 'store_conf'): + Document(dataset_path, manager=processor) + with self.assertRaisesRegex(ValueError, 'pure local map store'): + Document(dataset_path, manager=processor, store_conf={'type': 'map'}) + + doc = Document( + dataset_path, + manager=processor, + store_conf={'type': 'milvus', 'kwargs': {'uri': 'http://localhost:19530'}}, + ) + try: + assert doc._manager._enable_path_monitoring is False + assert doc._impl._enable_path_monitoring is False + assert doc._impl._dataset_path == doc._manager._origin_path + finally: + doc.stop() + def test_get_nodes(self): doc = Document(self._build_dataset()) doc.create_node_group('chunk1', parent=Document.CoarseChunk, From 9fe7e1a04ea2b84430f4a9d46ebfe47ae9f55f04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 7 Apr 2026 16:20:59 +0800 Subject: [PATCH 39/46] Export DocServer and trim docservice docs --- lazyllm/docs/tools/tool_rag.py | 56 ---------------------------------- lazyllm/tools/__init__.py | 3 +- lazyllm/tools/rag/__init__.py | 2 ++ 3 files changed, 4 insertions(+), 57 deletions(-) diff --git a/lazyllm/docs/tools/tool_rag.py b/lazyllm/docs/tools/tool_rag.py index 1a0e7698c..3e395c680 100644 --- a/lazyllm/docs/tools/tool_rag.py +++ b/lazyllm/docs/tools/tool_rag.py @@ -11,12 +11,6 @@ add_doc_service_english_doc = functools.partial( utils.add_english_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service') ) -add_doc_service_base_chinese_doc = functools.partial( - utils.add_chinese_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service.base') -) 
-add_doc_service_base_english_doc = functools.partial( - utils.add_english_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service.base') -) add_english_doc('Document', '''\ Initialize a document management module with optional embedding, storage, and user interface. @@ -146,56 +140,6 @@ 当前不支持跨算法 transfer。可选字段 ``target_filename`` 与 ``target_file_path`` 用于覆盖目标文档记录的文件名或文件路径。 ''') -add_doc_service_base_english_doc('TransferItem', '''\ -Single item in a document transfer request. - -Args: - doc_id (str): Source document ID. - target_doc_id (str): Required destination document ID. Must be unique in the target knowledge base. - source_kb_id (str): Source knowledge-base ID. - source_algo_id (str): Source algorithm ID. - target_kb_id (str): Destination knowledge-base ID. - target_algo_id (str): Destination algorithm ID. - target_metadata (Optional[Dict[str, Any]]): Metadata patch applied on top of the source document metadata for the transferred target document. - target_filename (Optional[str]): Target file name override. - target_file_path (Optional[str]): Target file path override. If set together with ``target_filename``, both must - point to the same basename. - mode (str): Transfer mode. Supports ``copy`` and ``move``. -''') - -add_doc_service_base_chinese_doc('TransferItem', '''\ -文档转移请求中的单个条目。 - -Args: - doc_id (str): 源文档 ID。 - target_doc_id (str): 必填的目标文档 ID,在目标知识库中必须唯一。 - source_kb_id (str): 源知识库 ID。 - source_algo_id (str): 源算法 ID。 - target_kb_id (str): 目标知识库 ID。 - target_algo_id (str): 目标算法 ID。 - target_metadata (Optional[Dict[str, Any]]): 基于源文档 metadata 做继承后,再覆盖写入目标文档的 metadata patch。 - target_filename (Optional[str]): 目标文件名覆盖值。 - target_file_path (Optional[str]): 目标文件路径覆盖值;若与 ``target_filename`` 同时传入,二者 basename - 必须一致。 - mode (str): 转移模式,支持 ``copy`` 与 ``move``。 -''') - -add_doc_service_base_english_doc('TransferRequest', '''\ -Batch transfer request for ``DocServer.transfer``. - -Args: - items (List[TransferItem]): Transfer items to execute. - idempotency_key (Optional[str]): Optional idempotency key for safe retries. 
-''') - -add_doc_service_base_chinese_doc('TransferRequest', '''\ -``DocServer.transfer`` 使用的批量转移请求。 - -Args: - items (List[TransferItem]): 要执行的转移条目列表。 - idempotency_key (Optional[str]): 可选幂等键,用于安全重试。 -''') - add_example('Document', '''\ >>> import lazyllm >>> from lazyllm.tools import Document diff --git a/lazyllm/tools/__init__.py b/lazyllm/tools/__init__.py index 9dd625ac9..62da11d3c 100644 --- a/lazyllm/tools/__init__.py +++ b/lazyllm/tools/__init__.py @@ -3,7 +3,7 @@ if TYPE_CHECKING: # flake8: noqa: E401 from .rag import (Document, GraphDocument, UrlGraphDocument, Reranker, Retriever, TempDocRetriever, - GraphRetriever, SentenceSplitter, LLMParser) + GraphRetriever, SentenceSplitter, LLMParser, DocServer) from .webpages import WebModule from .fs import (LazyLLMFSBase, CloudFSBufferedFile, CloudFS, CloudFsWatchdog, FeishuFS, ConfluenceFS, NotionFS, GoogleDriveFS, OneDriveFS, YuqueFS, OnesFS, S3FS, @@ -55,6 +55,7 @@ def __getattr__(name: str): _SUBMOD_MAP = { 'rag': [ 'Document', + 'DocServer', 'GraphDocument', 'UrlGraphDocument', 'Reranker', diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index a99670771..03f3039df 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -5,6 +5,7 @@ # flake8: noqa: E402 from .document import Document from .doc_manager import DocManager +from .doc_service import DocServer from .graph_document import GraphDocument, UrlGraphDocument from .retriever import Retriever, TempDocRetriever, ContextRetriever, WeightedRetriever, PriorityRetriever from .graph_retriever import GraphRetriever @@ -31,6 +32,7 @@ 'add_post_action_for_default_reader', 'Document', 'DocManager', + 'DocServer', 'GraphDocument', 'UrlGraphDocument', 'Reranker', From 49535df2f69697cb5298261ccd2f2bdbdbb37bc6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Tue, 7 Apr 2026 16:32:08 +0800 Subject: [PATCH 40/46] Unify DocServer doc registration --- lazyllm/docs/tools/tool_rag.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/lazyllm/docs/tools/tool_rag.py b/lazyllm/docs/tools/tool_rag.py index 3e395c680..1f9ebad9d 100644 --- a/lazyllm/docs/tools/tool_rag.py +++ b/lazyllm/docs/tools/tool_rag.py @@ -5,12 +5,6 @@ add_chinese_doc = functools.partial(utils.add_chinese_doc, module=importlib.import_module('lazyllm.tools')) add_english_doc = functools.partial(utils.add_english_doc, module=importlib.import_module('lazyllm.tools')) add_example = functools.partial(utils.add_example, module=importlib.import_module('lazyllm.tools')) -add_doc_service_chinese_doc = functools.partial( - utils.add_chinese_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service') -) -add_doc_service_english_doc = functools.partial( - utils.add_english_doc, module=importlib.import_module('lazyllm.tools.rag.doc_service') -) add_english_doc('Document', '''\ Initialize a document management module with optional embedding, storage, and user interface. @@ -56,7 +50,7 @@ enable_path_monitoring (Optional[bool]): 是否监控本地数据目录的文件新增和删除。仅在未接入 ``DocServer`` / ``DocumentProcessor`` 的本地模式下默认开启。 ''') -add_doc_service_english_doc('DocServer', '''\ +add_english_doc('DocServer', '''\ Primary entry point of the refactored document service. ``DocServer`` manages document upload/add/reparse/delete flows, task tracking, knowledge-base management, @@ -75,7 +69,7 @@ launcher: Launcher used to start local services. 
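A hedged client-side sketch (the URL, document ids, knowledge-base ids and algorithm id are placeholders; the call shape assumes ``DocServer(url=...)`` forwards the documented request models in the same way the add/upload flow does):

>>> from lazyllm.tools.rag.doc_service import DocServer
>>> from lazyllm.tools.rag.doc_service.base import TransferItem, TransferRequest
>>> server = DocServer(url="http://127.0.0.1:8848")
>>> request = TransferRequest(items=[TransferItem(
...     doc_id="doc-1", target_doc_id="doc-1-copy", mode="copy",
...     source_kb_id="kb_a", source_algo_id="demo_algo",
...     target_kb_id="kb_b", target_algo_id="demo_algo")])
>>> resp = server.transfer(request)
>>> # Transfer runs asynchronously; poll the returned task_id with server.get_task(task_id).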
''') -add_doc_service_chinese_doc('DocServer', '''\ +add_chinese_doc('DocServer', '''\ 重构后文档服务的主入口。 ``DocServer`` 负责文档上传/添加/重解析/删除、任务跟踪、知识库管理、chunk 查看,以及跨知识库文档转移。 @@ -93,7 +87,7 @@ launcher: 本地服务启动器。 ''') -add_doc_service_english_doc('DocServer.list_chunks', '''\ +add_english_doc('DocServer.list_chunks', '''\ List parsed chunks for a document through the ``/v1/chunks`` endpoint. Args: @@ -109,7 +103,7 @@ Paginated chunk data including ``items`` and ``total``. ''') -add_doc_service_chinese_doc('DocServer.list_chunks', '''\ +add_chinese_doc('DocServer.list_chunks', '''\ 通过 ``/v1/chunks`` 接口分页查看文档的解析 chunk。 Args: @@ -125,7 +119,7 @@ 包含 ``items`` 与 ``total`` 的分页结果。 ''') -add_doc_service_english_doc('DocServer.transfer', '''\ +add_english_doc('DocServer.transfer', '''\ Transfer parsed documents between knowledge bases under the same algorithm. The request body is a ``TransferRequest``. Each transfer item must provide a unique ``target_doc_id`` in the target @@ -133,7 +127,7 @@ ``target_file_path`` can override the destination file name/path recorded for the transferred document. ''') -add_doc_service_chinese_doc('DocServer.transfer', '''\ +add_chinese_doc('DocServer.transfer', '''\ 在同一算法下的不同知识库之间转移已解析文档。 请求体为 ``TransferRequest``。每个转移项都必须在目标知识库中提供唯一的 ``target_doc_id``。 From 473f5e2aa6e39dcbba1f4953011cbf8aa71b816d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 8 Apr 2026 10:02:51 +0800 Subject: [PATCH 41/46] Address doc_service review feedback --- examples/rag/doc_service_mock_example.py | 157 +- examples/rag/doc_service_standalone.py | 239 +- lazyllm/components/deploy/relay/server.py | 19 +- lazyllm/docs/tools/tool_rag.py | 128 +- lazyllm/tools/rag/doc_service/base.py | 129 +- lazyllm/tools/rag/doc_service/doc_manager.py | 318 +-- .../rag/doc_service/doc_server.openapi.json | 2078 +++++++++++++++++ lazyllm/tools/rag/doc_service/doc_server.py | 221 +- .../tools/rag/doc_service/parser_client.py | 119 + lazyllm/tools/rag/doc_service/utils.py | 70 + lazyllm/tools/rag/parsing_service/base.py | 5 +- lazyllm/tools/rag/parsing_service/queue.py | 2 + lazyllm/tools/rag/parsing_service/server.py | 14 +- .../RAG/test_doc_service_doc_server.py | 54 +- 14 files changed, 2876 insertions(+), 677 deletions(-) create mode 100644 lazyllm/tools/rag/doc_service/doc_server.openapi.json create mode 100644 lazyllm/tools/rag/doc_service/parser_client.py create mode 100644 lazyllm/tools/rag/doc_service/utils.py diff --git a/examples/rag/doc_service_mock_example.py b/examples/rag/doc_service_mock_example.py index 0329cc588..382652407 100644 --- a/examples/rag/doc_service_mock_example.py +++ b/examples/rag/doc_service_mock_example.py @@ -1,125 +1,84 @@ -'''DocService mock quickstart. +'''Connect a Document to a deployed DocServer. 
-Run: - python examples/rag/doc_service_mock_example.py +Start DocServer first: + python examples/rag/doc_service_standalone.py --wait + +Run this example: + python examples/rag/doc_service_mock_example.py --doc-server-url http://127.0.0.1:8848 ''' from __future__ import annotations import argparse -import io import os import tempfile import time -import requests - from lazyllm import Document +from lazyllm.tools.rag.doc_service import DocServer +from lazyllm.tools.rag.doc_service.base import AddFileItem, AddRequest + + +def _normalize_base_url(url: str) -> str: + url = url.rstrip('/') + if url.endswith('/_call') or url.endswith('/generate'): + return url.rsplit('/', 1)[0] + return url -def _wait_task(base_url: str, task_id: str, targets: set[str], timeout: float = 10.0): - start = time.time() - while time.time() - start < timeout: - resp = requests.get(f'{base_url}/v1/tasks/{task_id}', timeout=5) - resp.raise_for_status() - task = resp.json()['data'] - if task['status'] in targets: +def _wait_task(server: DocServer, task_id: str, timeout: float = 30.0): + deadline = time.time() + timeout + while time.time() < deadline: + task = server.get_task(task_id)['data'] + if task['status'] in {'SUCCESS', 'FAILED', 'CANCELED', 'DELETED'}: return task - time.sleep(0.2) - raise TimeoutError(f'task {task_id} did not reach {targets}') + time.sleep(0.5) + raise TimeoutError(f'task {task_id} did not finish in time') def main(): - parser = argparse.ArgumentParser(description='DocService mock quickstart.') - parser.add_argument('--wait', action='store_true', help='Keep server alive for manual API/docs inspection.') + parser = argparse.ArgumentParser(description='Connect a Document to an existing DocServer.') + parser.add_argument('--doc-server-url', type=str, required=True, help='Existing DocServer base URL.') + parser.add_argument('--algo-id', type=str, default='doc_service_demo_algo', help='Document algorithm ID.') + parser.add_argument('--kb-id', type=str, default='__default__', help='Knowledge base ID.') args = parser.parse_args() - with tempfile.TemporaryDirectory(prefix='lazyllm_doc_service_demo_') as tmp: - storage = os.path.join(tmp, 'uploads') - os.makedirs(storage, exist_ok=True) - seed_path = os.path.join(storage, 'seed.txt') - with open(seed_path, 'w', encoding='utf-8') as f: - f.write('seed content') + doc_server = DocServer(url=_normalize_base_url(args.doc_server_url)) + with tempfile.TemporaryDirectory(prefix='lazyllm_doc_service_example_') as dataset_dir: + file_path = os.path.join(dataset_dir, 'demo.txt') + with open(file_path, 'w', encoding='utf-8') as file: + file.write('hello from a real doc_service example\n') - doc = Document( - dataset_path=storage, - manager=True, - name='demo_doc_service', - ) - doc.start() + # Step 1: create a Document and bind it to the deployed DocServer. 
+ document = Document(dataset_path=dataset_dir, manager=doc_server, name=args.algo_id) + document.start() try: - base_url = doc.manager.url.rsplit('/', 1)[0] - print(f'DocService URL: {base_url}') - print(f'Swagger Docs: {base_url}/docs') - - upload_resp = requests.post( - f'{base_url}/v1/docs/upload', - params={'kb_id': 'kb_demo', 'algo_id': '__default__'}, - files=[('files', ('demo.txt', io.BytesIO(b'hello lazyllm rag'), 'text/plain'))], - timeout=10, - ) - upload_resp.raise_for_status() - upload_item = upload_resp.json()['data']['items'][0] - doc_id = upload_item['doc_id'] - task_id = upload_item['task_id'] - _wait_task(base_url, task_id, {'SUCCESS'}) - - patch_resp = requests.post( - f'{base_url}/v1/docs/metadata/patch', - json={ - 'kb_id': 'kb_demo', - 'algo_id': '__default__', - 'items': [{'doc_id': doc_id, 'patch': {'owner': 'demo_user', 'scene': 'quickstart'}}], - }, - timeout=10, - ) - patch_resp.raise_for_status() - patch_task = patch_resp.json()['data']['items'][0]['task_id'] - _wait_task(base_url, patch_task, {'SUCCESS'}) - - reparse_resp = requests.post( - f'{base_url}/v1/docs/reparse', - json={'kb_id': 'kb_demo', 'algo_id': '__default__', 'doc_ids': [doc_id]}, - timeout=10, - ) - reparse_resp.raise_for_status() - reparse_task = reparse_resp.json()['data']['task_ids'][0] - _wait_task(base_url, reparse_task, {'SUCCESS'}) - - add_resp = requests.post( - f'{base_url}/v1/docs/add', - json={'kb_id': 'kb_demo', 'algo_id': '__default__', 'items': [{'file_path': seed_path}]}, - timeout=10, - ) - add_resp.raise_for_status() - add_task = add_resp.json()['data']['items'][0]['task_id'] - _wait_task(base_url, add_task, {'SUCCESS'}) - - docs_resp = requests.get( - f'{base_url}/v1/docs', - params={'kb_id': 'kb_demo', 'include_deleted_or_canceled': False}, - timeout=10, - ) - docs_resp.raise_for_status() - docs = docs_resp.json()['data']['items'] - print(f'Current docs in kb_demo: {len(docs)}') - - delete_resp = requests.post( - f'{base_url}/v1/docs/delete', - json={'kb_id': 'kb_demo', 'algo_id': '__default__', 'doc_ids': [doc_id]}, - timeout=10, - ) - delete_resp.raise_for_status() - delete_task = delete_resp.json()['data']['items'][0]['task_id'] - _wait_task(base_url, delete_task, {'DELETED'}) - print('Doc lifecycle demo completed.') - if args.wait: - print('Server is running. Press Ctrl+C to stop...') - while True: - time.sleep(1) + base_url = _normalize_base_url(args.doc_server_url) + print(f'DocServer URL: {base_url}') + print(f'DocServer Docs: {base_url}/docs') + + # Step 2: add a local file through the DocServer client. + response = doc_server.add(AddRequest( + kb_id=args.kb_id, + algo_id=args.algo_id, + items=[AddFileItem(file_path=file_path)], + )) + item = response['data']['items'][0] + print(f'Doc ID: {item["doc_id"]}') + print(f'Task ID: {item["task_id"]}') + + # Step 3: wait for the asynchronous parse task to finish. + task = _wait_task(doc_server, item['task_id']) + print(f'Task Status: {task["status"]}') + + # Step 4: list documents from the target knowledge base. + docs = doc_server.list_docs( + kb_id=args.kb_id, algo_id=args.algo_id, include_deleted_or_canceled=False + )['data']['items'] + print(f'Doc Count In {args.kb_id}: {len(docs)}') finally: - doc.stop() + document.stop() if __name__ == '__main__': diff --git a/examples/rag/doc_service_standalone.py b/examples/rag/doc_service_standalone.py index 64a16ce6e..8d98d1f24 100644 --- a/examples/rag/doc_service_standalone.py +++ b/examples/rag/doc_service_standalone.py @@ -1,21 +1,15 @@ -'''Start standalone DocService. 
+'''Start a standalone DocServer example. -Modes: -1. Full stack mode (default): starts algorithm registration, real - DocumentProcessor, and DocServer in one process. -2. External parser mode: starts only DocServer and connects to an existing - parsing service with ``--parser-url``. - -Run: +Examples: python examples/rag/doc_service_standalone.py --wait python examples/rag/doc_service_standalone.py --parser-url http://127.0.0.1:9966 --wait + python examples/rag/doc_service_standalone.py --export-openapi ''' from __future__ import annotations import argparse import os -import threading import time from typing import Any, Dict @@ -23,22 +17,16 @@ from lazyllm import Document from lazyllm.tools.rag.doc_service import DocServer +from lazyllm.tools.rag.doc_service.doc_server import DEFAULT_OPENAPI_OUTPUT_PATH from lazyllm.tools.rag.parsing_service import DocumentProcessor REAL_ALGO_ID = 'real-standalone-algo' FIXED_DB_ROOT = './tmp/db' -DEFAULT_OPENAPI_PATH = os.path.join(FIXED_DB_ROOT, 'doc_service.openapi.json') +DEFAULT_OPENAPI_PATH = DEFAULT_OPENAPI_OUTPUT_PATH def _make_db_config(db_name: str) -> Dict[str, Any]: - return { - 'db_type': 'sqlite', - 'user': None, - 'password': None, - 'host': None, - 'port': None, - 'db_name': db_name, - } + return {'db_type': 'sqlite', 'user': None, 'password': None, 'host': None, 'port': None, 'db_name': db_name} def _prepare_runtime_paths() -> Dict[str, str]: @@ -47,36 +35,25 @@ def _prepare_runtime_paths() -> Dict[str, str]: 'root_dir': FIXED_DB_ROOT, 'storage_dir': os.path.join(FIXED_DB_ROOT, 'uploads'), 'store_dir': os.path.join(FIXED_DB_ROOT, 'store'), - 'parser_db': os.path.join(FIXED_DB_ROOT, 'parser.sqlite'), 'doc_db': os.path.join(FIXED_DB_ROOT, 'doc_service.sqlite'), + 'parser_db': os.path.join(FIXED_DB_ROOT, 'parser.sqlite'), } os.makedirs(paths['storage_dir'], exist_ok=True) os.makedirs(paths['store_dir'], exist_ok=True) return paths -def _wait_until(predicate, timeout: float = 20.0, interval: float = 0.1): +def _wait_http_ok(url: str, timeout: float = 20.0): deadline = time.time() + timeout - last = None while time.time() < deadline: - last = predicate() - if last: - return last - time.sleep(interval) - raise RuntimeError(f'condition not satisfied before timeout, last={last!r}') - - -def _wait_http_ok(url: str, timeout: float = 20.0): - def _poll(): try: - resp = requests.get(url, timeout=3) - if resp.status_code == 200: - return resp + response = requests.get(url, timeout=3) + if response.status_code == 200: + return except Exception: - return None - return None - - return _wait_until(_poll, timeout=timeout) + pass + time.sleep(0.2) + raise RuntimeError(f'http service is not ready: {url}') def _build_store_conf(root_dir: str) -> Dict[str, Any]: @@ -85,147 +62,137 @@ def _build_store_conf(root_dir: str) -> Dict[str, Any]: open(segment_store_path, 'a', encoding='utf-8').close() open(milvus_store_path, 'a', encoding='utf-8').close() return { - 'segment_store': { - 'type': 'map', - 'kwargs': {'uri': segment_store_path}, - }, + 'segment_store': {'type': 'map', 'kwargs': {'uri': segment_store_path}}, 'vector_store': { 'type': 'milvus', 'kwargs': { 'uri': milvus_store_path, - 'index_kwargs': { - 'index_type': 'FLAT', - 'metric_type': 'COSINE', - }, + 'index_kwargs': {'index_type': 'FLAT', 'metric_type': 'COSINE'}, }, }, } -def _start_full_stack(args): - paths = _prepare_runtime_paths() +def _wait_algo_ready(parser_url: str, algo_id: str, timeout: float = 20.0): + deadline = time.time() + timeout + while time.time() < deadline: + response = 
requests.get(f'{parser_url}/algo/list', timeout=5) + if response.status_code == 200: + items = response.json().get('data', []) + if any(item.get('algo_id') == algo_id for item in items): + return + time.sleep(0.2) + raise RuntimeError(f'algorithm is not ready: {algo_id}') - parser = DocumentProcessor( - port=args.parser_port, - db_config=_make_db_config(paths['parser_db']), - num_workers=args.num_workers, - ) - parser.start() - parser_base_url = parser._impl._url.rsplit('/', 1)[0] - _wait_http_ok(f'{parser_base_url}/health') - store_conf = _build_store_conf(paths['store_dir']) +def _register_demo_algorithm(parser_url: str, algo_id: str, store_dir: str): + # Step 2: register a real Document algorithm on the parsing service. document = Document( dataset_path=None, - name=args.algo_id, - embed={'vec_dense': lambda x: [1.0, 2.0, 3.0]}, - store_conf=store_conf, + name=algo_id, + embed={'vec_dense': lambda text: [1.0, 2.0, 3.0]}, + store_conf=_build_store_conf(store_dir), display_name='Standalone Real Algo', - manager=DocumentProcessor(url=parser_base_url), + manager=DocumentProcessor(url=parser_url), description='Algorithm registered by standalone doc service example', ) document.create_node_group( name='line', - transform=lambda x: x.split('\n'), + transform=lambda text: text.split('\n'), parent='CoarseChunk', display_name='Line Chunk', ) document.activate_group('CoarseChunk', embed_keys=['vec_dense']) document.activate_group('line', embed_keys=['vec_dense']) document.start() + _wait_algo_ready(parser_url, algo_id) + return document - _wait_until( - lambda: any( - item.get('algo_id') == args.algo_id - for item in requests.get(f'{parser_base_url}/algo/list', timeout=5).json().get('data', []) - ) - ) - - server = DocServer( - storage_dir=paths['storage_dir'], - db_config=_make_db_config(paths['doc_db']), - parser_url=parser_base_url, - port=args.port, - ) - server.start() - base_url = server.url.rsplit('/', 1)[0] - _wait_http_ok(f'{base_url}/v1/health') - - print(f'DocService URL: {base_url}', flush=True) - print(f'DocService Docs: {base_url}/docs', flush=True) - print(f'Parser URL: {parser_base_url}', flush=True) - print(f'Parser Docs: {parser_base_url}/docs', flush=True) - print(f'Algorithm ID: {args.algo_id}', flush=True) - print(f'Storage Dir: {paths["storage_dir"]}', flush=True) - print(f'Store Dir: {paths["store_dir"]}', flush=True) - print(f'Doc DB: {paths["doc_db"]}', flush=True) - print(f'Parser DB: {paths["parser_db"]}', flush=True) - print(f'DB Root: {paths["root_dir"]}', flush=True) - try: - if args.wait: - print('Full stack is running. Press Ctrl+C to stop...', flush=True) - threading.Event().wait() - finally: - server.stop() - try: - parser.drop_algorithm(args.algo_id) - except Exception: - pass - parser.stop() - - -def _start_doc_server_only(args): - paths = _prepare_runtime_paths() - server = DocServer( - storage_dir=paths['storage_dir'], - db_config=_make_db_config(paths['doc_db']), - parser_url=args.parser_url, - port=args.port, - ) - server.start() - base_url = server.url.rsplit('/', 1)[0] - print(f'DocService URL: {base_url}', flush=True) - print(f'DocService Docs: {base_url}/docs', flush=True) - print(f'Parser URL: {args.parser_url}', flush=True) - print(f'Storage Dir: {paths["storage_dir"]}', flush=True) - print(f'Doc DB: {paths["doc_db"]}', flush=True) - print(f'DB Root: {paths["root_dir"]}', flush=True) - - try: - if args.wait: - print('DocService is running. 
Press Ctrl+C to stop...', flush=True) - while True: - time.sleep(1) - finally: - server.stop() +def _start_local_parser(parser_port: int, parser_db: str): + # Step 1: start a local parsing service. + parser = DocumentProcessor(port=parser_port, db_config=_make_db_config(parser_db), num_workers=1) + parser.start() + parser_url = parser._impl._url.rsplit('/', 1)[0] + _wait_http_ok(f'{parser_url}/health') + return parser, parser_url def main(): - parser = argparse.ArgumentParser(description='Standalone DocService server.') + parser = argparse.ArgumentParser(description='Standalone DocServer example.') parser.add_argument('--port', type=int, default=8848, help='DocServer listen port.') parser.add_argument('--parser-port', type=int, default=9966, help='DocumentProcessor listen port.') parser.add_argument('--parser-url', type=str, default=None, help='Existing parsing service base URL.') - parser.add_argument('--algo-id', type=str, default=REAL_ALGO_ID, help='Algorithm id to register in full stack mode.') - parser.add_argument('--num-workers', type=int, default=1, help='DocumentProcessor worker count.') - parser.add_argument('--wait', action='store_true', help='Keep server alive for manual API inspection.') + parser.add_argument('--algo-id', type=str, default=REAL_ALGO_ID, help='Algorithm ID for the local demo setup.') + parser.add_argument('--wait', action='store_true', help='Keep the example running for manual inspection.') parser.add_argument( '--export-openapi', type=str, + nargs='?', + const=DEFAULT_OPENAPI_PATH, default=None, - help=f'Export current DocService OpenAPI JSON before startup. Default path example: {DEFAULT_OPENAPI_PATH}', + help=f'Export DocServer OpenAPI JSON and exit. Default path: {DEFAULT_OPENAPI_PATH}', ) args = parser.parse_args() if args.export_openapi: - output_path = DocServer.export_openapi(args.export_openapi) - print(f'OpenAPI exported: {output_path}', flush=True) + print(f'OpenAPI exported: {DocServer.export_openapi(args.export_openapi)}', flush=True) return - if args.parser_url: - _start_doc_server_only(args) - else: - _start_full_stack(args) + paths = _prepare_runtime_paths() + parser_server = None + server = None + document = None + parser_url = args.parser_url + + try: + if parser_url: + parser_url = parser_url.rstrip('/') + _wait_http_ok(f'{parser_url}/health') + else: + parser_server, parser_url = _start_local_parser(args.parser_port, paths['parser_db']) + document = _register_demo_algorithm(parser_url, args.algo_id, paths['store_dir']) + + # Step 3: start DocServer and point it to the parsing service. + server = DocServer( + storage_dir=paths['storage_dir'], + db_config=_make_db_config(paths['doc_db']), + parser_url=parser_url, + port=args.port, + ) + server.start() + base_url = server.url.rsplit('/', 1)[0] + _wait_http_ok(f'{base_url}/v1/health') + + print(f'DocServer URL: {base_url}', flush=True) + print(f'DocServer Docs: {base_url}/docs', flush=True) + print(f'Parser URL: {parser_url}', flush=True) + print(f'Storage Dir: {paths["storage_dir"]}', flush=True) + print(f'Doc DB: {paths["doc_db"]}', flush=True) + if document: + print(f'Algorithm ID: {args.algo_id}', flush=True) + + if args.wait: + # Step 4: keep the services alive for manual API testing. + print('Services are running. 
Press Ctrl+C to stop.', flush=True) + while True: + time.sleep(1) + finally: + if server: + try: + server.stop() + except Exception: + pass + if document: + try: + document.stop() + except Exception: + pass + if parser_server: + try: + parser_server.stop() + except Exception: + pass if __name__ == '__main__': diff --git a/lazyllm/components/deploy/relay/server.py b/lazyllm/components/deploy/relay/server.py index 8bfadc251..b98457ac7 100644 --- a/lazyllm/components/deploy/relay/server.py +++ b/lazyllm/components/deploy/relay/server.py @@ -1,15 +1,15 @@ import argparse -import os -import sys +import asyncio +import codecs import inspect -import traceback -from types import GeneratorType -import time +import os import pickle -import codecs -import asyncio import functools +import sys +import time +import traceback from functools import partial +from types import GeneratorType from typing import Callable @@ -34,9 +34,8 @@ def _inject_pythonpath(argv): from lazyllm.common.utils import str2obj # noqa: E402 import uvicorn # noqa: E402 import lazyllm # noqa: E402 -from lazyllm import kwargs, package, load_obj # noqa: E402 -from lazyllm import FastapiApp, globals # noqa: E402 -from lazyllm.common import _trim_traceback, _register_trim_module # noqa: E402 +from lazyllm import FastapiApp, globals, kwargs, load_obj, package # noqa: E402 +from lazyllm.common import _register_trim_module, _trim_traceback # noqa: E402 from fastapi import FastAPI, Request # noqa: E402 from fastapi.responses import Response, StreamingResponse # noqa: E402 diff --git a/lazyllm/docs/tools/tool_rag.py b/lazyllm/docs/tools/tool_rag.py index 1f9ebad9d..c7e62c1ae 100644 --- a/lazyllm/docs/tools/tool_rag.py +++ b/lazyllm/docs/tools/tool_rag.py @@ -51,7 +51,7 @@ ''') add_english_doc('DocServer', '''\ -Primary entry point of the refactored document service. +Primary entry point of the document service. ``DocServer`` manages document upload/add/reparse/delete flows, task tracking, knowledge-base management, chunk inspection, and cross-kb transfer. It is the recommended replacement for the legacy ``DocManager`` / @@ -70,7 +70,7 @@ ''') add_chinese_doc('DocServer', '''\ -重构后文档服务的主入口。 +文档服务的主入口。 ``DocServer`` 负责文档上传/添加/重解析/删除、任务跟踪、知识库管理、chunk 查看,以及跨知识库文档转移。 它是 legacy ``DocManager`` / ``DocListManager`` API 的推荐替代方案。 @@ -87,6 +87,130 @@ launcher: 本地服务启动器。 ''') +add_english_doc('DocServer.add', '''\ +Add existing local files through the ``/v1/docs/add`` endpoint. + +Use this method when the file paths are already accessible on the DocServer host. The request body is an +``AddRequest`` containing ``kb_id``, ``algo_id``, and ``items``. Each item can provide ``file_path``, +optional ``doc_id``, and optional ``metadata``. + +Returns: + Standard API response. ``data["items"]`` contains the accepted ``doc_id`` and asynchronous ``task_id``. +''') + +add_chinese_doc('DocServer.add', '''\ +通过 ``/v1/docs/add`` 接口添加服务端可直接访问的本地文件。 + +当文件路径已经对 DocServer 所在机器可见时,使用该方法。请求体为 ``AddRequest``,包含 ``kb_id``、``algo_id`` +和 ``items``。每个 item 可提供 ``file_path``,以及可选的 ``doc_id``、``metadata``。 + +Returns: + 标准 API 响应。``data["items"]`` 中包含接受后的 ``doc_id`` 和异步 ``task_id``。 +''') + +add_english_doc('DocServer.upload', '''\ +Upload files into DocServer-managed storage through the ``/v1/docs/upload`` flow. + +Use this method when you want DocServer to manage uploaded copies of the source files. The request body is an +``UploadRequest`` with ``kb_id``, ``algo_id``, and ``items``. 
Each item uses ``file_path`` as the local source +path and can optionally include ``doc_id`` or ``metadata``. + +Returns: + Standard API response. ``data["items"]`` contains the accepted ``doc_id`` and asynchronous ``task_id``. +''') + +add_chinese_doc('DocServer.upload', '''\ +通过 ``/v1/docs/upload`` 流程将文件上传到 DocServer 管理的存储目录。 + +当你希望由 DocServer 保存上传副本时,使用该方法。请求体为 ``UploadRequest``,包含 ``kb_id``、``algo_id`` +和 ``items``。每个 item 使用 ``file_path`` 作为本地源路径,也可以附带可选的 ``doc_id``、``metadata``。 + +Returns: + 标准 API 响应。``data["items"]`` 中包含接受后的 ``doc_id`` 和异步 ``task_id``。 +''') + +add_english_doc('DocServer.reparse', '''\ +Reparse existing documents through the ``/v1/docs/reparse`` endpoint. + +The request body is a ``ReparseRequest`` with ``kb_id``, ``algo_id``, and ``doc_ids``. Use it after metadata +or parsing configuration changes when you want to enqueue new parse tasks for existing documents. +''') + +add_chinese_doc('DocServer.reparse', '''\ +通过 ``/v1/docs/reparse`` 接口重新解析已有文档。 + +请求体为 ``ReparseRequest``,包含 ``kb_id``、``algo_id`` 和 ``doc_ids``。当元数据或解析配置变更后, +需要为已有文档重新入队解析任务时,可使用该方法。 +''') + +add_english_doc('DocServer.delete', '''\ +Delete documents from a knowledge base through the ``/v1/docs/delete`` endpoint. + +The request body is a ``DeleteRequest`` with ``kb_id``, ``algo_id``, and ``doc_ids``. Deletion is asynchronous, +so the returned ``task_id`` should be tracked through the task APIs when you need final status. +''') + +add_chinese_doc('DocServer.delete', '''\ +通过 ``/v1/docs/delete`` 接口从知识库中删除文档。 + +请求体为 ``DeleteRequest``,包含 ``kb_id``、``algo_id`` 和 ``doc_ids``。删除是异步操作,因此如果需要最终状态, +应继续通过任务接口跟踪返回的 ``task_id``。 +''') + +add_english_doc('DocServer.patch_metadata', '''\ +Patch document metadata through the ``/v1/docs/metadata/patch`` endpoint. + +The request body is a ``MetadataPatchRequest`` with ``kb_id``, ``algo_id``, and ``items``. Each item targets one +document and carries a partial metadata patch in ``patch``. +''') + +add_chinese_doc('DocServer.patch_metadata', '''\ +通过 ``/v1/docs/metadata/patch`` 接口更新文档元数据。 + +请求体为 ``MetadataPatchRequest``,包含 ``kb_id``、``algo_id`` 和 ``items``。每个 item 指向一个文档, +并在 ``patch`` 中携带需要合并的局部元数据。 +''') + +add_english_doc('DocServer.get_task', '''\ +Get one task record through the ``/v1/tasks/{task_id}`` endpoint. + +Args: + task_id (str): Task ID returned by add, upload, reparse, delete, transfer, or metadata patch operations. + +Returns: + Standard API response with the current task status and task payload. +''') + +add_chinese_doc('DocServer.get_task', '''\ +通过 ``/v1/tasks/{task_id}`` 接口获取单个任务记录。 + +Args: + task_id (str): add、upload、reparse、delete、transfer 或 metadata patch 等操作返回的任务 ID。 + +Returns: + 包含当前任务状态和任务负载的标准 API 响应。 +''') + +add_english_doc('DocServer.cancel_task', '''\ +Cancel a waiting task through the ``/v1/tasks/cancel`` endpoint. + +Args: + task_id (str): Task ID to cancel. + +Returns: + Standard API response indicating whether the task was canceled successfully. +''') + +add_chinese_doc('DocServer.cancel_task', '''\ +通过 ``/v1/tasks/cancel`` 接口取消一个处于等待中的任务。 + +Args: + task_id (str): 要取消的任务 ID。 + +Returns: + 表示任务是否取消成功的标准 API 响应。 +''') + add_english_doc('DocServer.list_chunks', '''\ List parsed chunks for a document through the ``/v1/chunks`` endpoint. 
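A minimal client-side sketch of the flow documented above, assuming a DocServer reachable at ``http://localhost:8848`` (the example script's default port) and the ``requests`` library. The endpoint paths, request fields, and status values are taken from the docstrings and schemas above; the file path, metadata, kb/algo IDs, and polling interval are illustrative placeholders, and the response envelope is assumed to follow the standard ``code``/``msg``/``data`` shape.

import time
import requests

BASE = 'http://localhost:8848/v1'  # assumed DocServer base URL

# Add a file that is already visible on the DocServer host (POST /v1/docs/add).
resp = requests.post(f'{BASE}/docs/add', json={
    'kb_id': '__default__',
    'algo_id': '__default__',
    'items': [{'file_path': '/data/example.txt', 'metadata': {'source': 'demo'}}],
}).json()
task_id = resp['data']['items'][0]['task_id']

# Parsing is asynchronous: poll the task record (GET /v1/tasks/{task_id})
# until it leaves the WAITING/WORKING states.
while True:
    task = requests.get(f'{BASE}/tasks/{task_id}').json()['data']
    if task['status'] not in ('WAITING', 'WORKING'):
        break
    time.sleep(1)
print(task['status'])

The same polling pattern applies to the ``task_id`` returned by the upload, reparse, delete, transfer, and metadata patch requests described above.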
diff --git a/lazyllm/tools/rag/doc_service/base.py b/lazyllm/tools/rag/doc_service/base.py index 61622bbc5..d3fd5412c 100644 --- a/lazyllm/tools/rag/doc_service/base.py +++ b/lazyllm/tools/rag/doc_service/base.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional from uuid import uuid4 -from pydantic import BaseModel, Field, model_validator +from pydantic import AliasChoices, BaseModel, Field, model_validator from ..parsing_service.base import TaskType @@ -58,17 +58,6 @@ def http_status(self): return BIZ_HTTP_CODE.get(self.biz_code, 500) -class TaskCreateRequest(BaseModel): - task_id: str = Field(default_factory=lambda: str(uuid4())) - task_type: TaskType - doc_id: str - kb_id: str = '__default__' - algo_id: str = '__default__' - metadata: Dict[str, Any] = Field(default_factory=dict) - priority: int = 0 - callback_url: Optional[str] = None - - class TaskCallbackRequest(BaseModel): callback_id: str = Field(default_factory=lambda: str(uuid4())) task_id: str @@ -91,11 +80,11 @@ def validate_file_path(self): return self -class AddRequest(BaseModel): +class DocItemsRequest(BaseModel): items: List[AddFileItem] kb_id: str = '__default__' algo_id: str = '__default__' - source_type: SourceType = SourceType.EXTERNAL + source_type: Optional[SourceType] = None idempotency_key: Optional[str] = None @model_validator(mode='after') @@ -105,21 +94,11 @@ def validate_items(self): return self -class UploadRequest(BaseModel): - items: List[AddFileItem] - kb_id: str = '__default__' - algo_id: str = '__default__' - source_type: SourceType = SourceType.API - idempotency_key: Optional[str] = None - - @model_validator(mode='after') - def validate_items(self): - if not self.items: - raise ValueError('items is required') - return self +AddRequest = DocItemsRequest +UploadRequest = DocItemsRequest -class ReparseRequest(BaseModel): +class _DocMutationRequest(BaseModel): doc_ids: List[str] kb_id: str = '__default__' algo_id: str = '__default__' @@ -132,24 +111,19 @@ def validate_doc_ids(self): return self -class DeleteRequest(BaseModel): - doc_ids: List[str] - kb_id: str = '__default__' - algo_id: str = '__default__' - idempotency_key: Optional[str] = None +class ReparseRequest(_DocMutationRequest): + pass - @model_validator(mode='after') - def validate_doc_ids(self): - if not self.doc_ids: - raise ValueError('doc_ids is required') - return self + +class DeleteRequest(_DocMutationRequest): + pass class TransferItem(BaseModel): doc_id: str target_doc_id: str - source_kb_id: str = '__default__' - source_algo_id: str = '__default__' + kb_id: str = Field(default='__default__', validation_alias=AliasChoices('kb_id', 'source_kb_id')) + algo_id: str = Field(default='__default__', validation_alias=AliasChoices('algo_id', 'source_algo_id')) target_kb_id: str target_algo_id: str target_metadata: Optional[Dict[str, Any]] = None @@ -157,6 +131,26 @@ class TransferItem(BaseModel): target_file_path: Optional[str] = None mode: str = 'copy' + @model_validator(mode='before') + @classmethod + def normalize_source_fields(cls, data): + if not isinstance(data, dict): + return data + normalized = dict(data) + if 'kb_id' not in normalized and 'source_kb_id' in normalized: + normalized['kb_id'] = normalized['source_kb_id'] + if 'algo_id' not in normalized and 'source_algo_id' in normalized: + normalized['algo_id'] = normalized['source_algo_id'] + return normalized + + @property + def source_kb_id(self) -> str: + return self.kb_id + + @property + def source_algo_id(self) -> str: + return self.algo_id + class 
TransferRequest(BaseModel): items: List[TransferItem] @@ -187,6 +181,55 @@ def validate_items(self): return self +class KbDeleteBatchRequest(BaseModel): + kb_ids: List[str] + idempotency_key: Optional[str] = None + + @model_validator(mode='after') + def validate_kb_ids(self): + if not self.kb_ids: + raise ValueError('kb_ids is required') + return self + + +class AlgorithmInfoRequest(BaseModel): + algo_id: str + + +class TaskInfoRequest(BaseModel): + task_id: str + + +class TaskBatchRequest(BaseModel): + task_ids: List[str] + + @model_validator(mode='after') + def validate_task_ids(self): + if not self.task_ids: + raise ValueError('task_ids is required') + return self + + +class TaskCancelRequest(BaseModel): + task_id: str + idempotency_key: Optional[str] = None + + +class TaskCallbackPayload(BaseModel): + callback_id: Optional[str] = None + task_id: Optional[str] = None + event_type: Optional[CallbackEventType] = None + status: Optional[DocStatus] = None + task_status: Optional[DocStatus] = None + error_code: Optional[str] = None + error_msg: Optional[str] = None + task_type: Optional[TaskType] = None + doc_id: Optional[str] = None + kb_id: Optional[str] = None + algo_id: Optional[str] = None + payload: Dict[str, Any] = Field(default_factory=dict) + + class KbCreateRequest(BaseModel): kb_id: str display_name: Optional[str] = None @@ -251,7 +294,7 @@ def validate_kb_ids(self): DOC_SERVICE_TASKS_TABLE_INFO = { 'name': 'lazyllm_doc_service_tasks', - 'comment': 'Doc service task state table', + 'comment': 'Doc service task history table', 'columns': [ {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, 'comment': 'Auto increment ID'}, @@ -351,7 +394,7 @@ def validate_kb_ids(self): PARSE_STATE_TABLE_INFO = { 'name': 'lazyllm_doc_parse_state', - 'comment': 'Latest parse state table', + 'comment': 'Latest parse state snapshot table', 'columns': [ {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, 'comment': 'Auto increment ID'}, @@ -380,7 +423,3 @@ def validate_kb_ids(self): 'comment': 'Updated time'}, ], } - - -def now_ts() -> datetime: - return datetime.now() diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index a088b5698..05f9603b1 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -1,16 +1,15 @@ from __future__ import annotations -import hashlib +from datetime import datetime, timedelta import json import os -from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Set from uuid import uuid4 -import requests import sqlalchemy from sqlalchemy.exc import IntegrityError +from lazyllm import LOG from lazyllm.thirdparty import fastapi from ..utils import BaseResponse, _get_default_db_config, _orm_to_dict @@ -37,197 +36,20 @@ TransferRequest, UploadRequest, DocStatus, - now_ts, ) -from ..parsing_service.base import ( - AddDocRequest as ParsingAddDocRequest, - CancelTaskRequest as ParsingCancelTaskRequest, - DeleteDocRequest as ParsingDeleteDocRequest, - FileInfo as ParsingFileInfo, - TransferParams as ParsingTransferParams, - UpdateMetaRequest as ParsingUpdateMetaRequest, +from .parser_client import ParserClient +from .utils import ( + from_json, + gen_doc_id, + hash_payload, + merge_transfer_metadata, + resolve_transfer_target_path, + sha256_file, + stable_json, + to_json, ) -def _to_json(data: Optional[Dict[str, Any]]) -> str: - return json.dumps(data or {}, ensure_ascii=False) - - -def 
_from_json(raw: Optional[str]) -> Dict[str, Any]: - if not raw: - return {} - try: - return json.loads(raw) - except Exception: - return {} - - -def gen_doc_id(file_path: str, doc_id: Optional[str] = None) -> str: - if doc_id: - return doc_id - return hashlib.sha256(file_path.encode()).hexdigest() - - -def _stable_json(data: Any) -> str: - return json.dumps(data, ensure_ascii=False, sort_keys=True, default=str) - - -def _hash_payload(data: Any) -> str: - return hashlib.sha256(_stable_json(data).encode()).hexdigest() - - -def _sha256_file(file_path: str) -> str: - digest = hashlib.sha256() - with open(file_path, 'rb') as fh: - for chunk in iter(lambda: fh.read(1024 * 1024), b''): - digest.update(chunk) - return digest.hexdigest() - - -def _merge_transfer_metadata( - source_metadata: Dict[str, Any], target_metadata: Optional[Dict[str, Any]] -) -> Dict[str, Any]: - metadata = dict(source_metadata or {}) - if target_metadata: - metadata.update(target_metadata) - return metadata - - -def _resolve_transfer_target_path( - source_path: str, target_filename: Optional[str], target_file_path: Optional[str] -) -> str: - if target_file_path: - return target_file_path - if target_filename: - base_dir = os.path.dirname(source_path) if source_path else '' - return os.path.join(base_dir, target_filename) if base_dir else target_filename - return source_path - - -class _ParserClient: - def __init__(self, parser_url: str): - parser_url = parser_url.rstrip('/') - if parser_url.endswith('/_call') or parser_url.endswith('/generate'): - parser_url = parser_url.rsplit('/', 1)[0] - self._parser_url = parser_url - - def _post(self, path: str, payload: Dict[str, Any]): - url = f'{self._parser_url}{path}' - resp = requests.post(url, json=payload, timeout=8) - if resp.status_code >= 400: - raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') - return resp.json() - - def _get(self, path: str, params: Optional[Dict[str, Any]] = None): - url = f'{self._parser_url}{path}' - resp = requests.get(url, params=params, timeout=8) - if resp.status_code >= 400: - raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') - return resp.json() - - def _delete(self, path: str, payload: Optional[Dict[str, Any]] = None): - url = f'{self._parser_url}{path}' - resp = requests.delete(url, json=payload, timeout=8) - if resp.status_code >= 400: - raise RuntimeError(f'parser http error: {resp.status_code} {resp.text}') - return resp.json() - - def _get_with_fallback(self, paths: List[str], params: Optional[Dict[str, Any]] = None): - last_error = None - for path in paths: - try: - return self._get(path, params=params) - except RuntimeError as exc: - last_error = exc - if '404' not in str(exc): - raise - if last_error is not None: - raise last_error - raise RuntimeError('parser http error: no path provided') - - def add_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, file_path: str, - metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, - callback_url: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None): - req = ParsingAddDocRequest( - task_id=task_id, - algo_id=algo_id, - kb_id=kb_id, - callback_url=callback_url, - feedback_url=callback_url, - file_infos=[ParsingFileInfo( - file_path=file_path, - doc_id=doc_id, - metadata=metadata or {}, - reparse_group=reparse_group, - transfer_params=( - ParsingTransferParams.model_validate(transfer_params) - if transfer_params is not None else None - ), - )], - ) - data = self._post('/doc/add', 
req.model_dump(mode='json')) - return BaseResponse.model_validate(data) - - def update_meta(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, - metadata: Optional[Dict[str, Any]] = None, file_path: Optional[str] = None, - callback_url: Optional[str] = None): - req = ParsingUpdateMetaRequest( - task_id=task_id, - algo_id=algo_id, - kb_id=kb_id, - callback_url=callback_url, - feedback_url=callback_url, - file_infos=[ParsingFileInfo(file_path=file_path, doc_id=doc_id, metadata=metadata or {})], - ) - data = self._post('/doc/meta/update', req.model_dump(mode='json')) - return BaseResponse.model_validate(data) - - def delete_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, - callback_url: Optional[str] = None): - req = ParsingDeleteDocRequest( - task_id=task_id, - algo_id=algo_id, - kb_id=kb_id, - doc_ids=[doc_id], - callback_url=callback_url, - feedback_url=callback_url, - ) - data = self._delete('/doc/delete', req.model_dump(mode='json')) - return BaseResponse.model_validate(data) - - def cancel_task(self, task_id: str): - req = ParsingCancelTaskRequest(task_id=task_id) - data = self._post('/doc/cancel', req.model_dump(mode='json')) - return BaseResponse.model_validate(data) - - def list_algorithms(self): - data = self._get_with_fallback(['/v1/algo/list', '/algo/list']) - return BaseResponse.model_validate(data) - - def get_algorithm_groups(self, algo_id: str): - try: - data = self._get_with_fallback([ - f'/v1/algo/{algo_id}/groups', - f'/algo/{algo_id}/group/info', - ]) - return BaseResponse.model_validate(data) - except RuntimeError as exc: - if '404' in str(exc): - return BaseResponse(code=404, msg='algo not found', data=None) - raise - - def list_doc_chunks(self, algo_id: str, kb_id: str, doc_id: str, group: str, offset: int, page_size: int): - data = self._get('/doc/chunks', params={ - 'algo_id': algo_id, - 'kb_id': kb_id, - 'doc_id': doc_id, - 'group': group, - 'offset': offset, - 'page_size': page_size, - }) - return BaseResponse.model_validate(data) - - class DocManager: def __init__( self, @@ -247,9 +69,9 @@ def __init__( DOC_SERVICE_TASKS_TABLE_INFO]}, ) self._ensure_indexes() - self._parser_client = _ParserClient(parser_url=parser_url) + self._parser_client = ParserClient(parser_url=parser_url) + self._parser_client.health() self._callback_url = callback_url - self._upsert_default_kb() def set_callback_url(self, callback_url: str): self._callback_url = callback_url @@ -295,7 +117,7 @@ def _ensure_indexes(self): for stmt in stmts: self._db_manager.execute_commit(stmt) - def _upsert_default_kb(self): + def _ensure_default_kb(self): self._ensure_kb('__default__', display_name='__default__') self._ensure_kb_algorithm('__default__', '__default__') self._cleanup_idempotency_records() @@ -303,7 +125,7 @@ def _upsert_default_kb(self): def _ensure_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, update_fields: Optional[Set[str]] = None): - now = now_ts() + now = datetime.now() with self._db_manager.get_session() as session: Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) row = session.query(Kb).filter(Kb.kb_id == kb_id).first() @@ -315,7 +137,7 @@ def _ensure_kb(self, kb_id: str, display_name: Optional[str] = None, description doc_count=0, status=KBStatus.ACTIVE.value, owner_id=owner_id, - meta=_to_json(meta), + meta=to_json(meta), created_at=now, updated_at=now, ) @@ -329,14 +151,14 @@ def _ensure_kb(self, kb_id: str, display_name: 
Optional[str] = None, description if 'owner_id' in update_fields: row.owner_id = owner_id if 'meta' in update_fields: - row.meta = _to_json(meta) + row.meta = to_json(meta) if row.status == KBStatus.DELETED.value: row.status = KBStatus.ACTIVE.value row.updated_at = now session.add(row) def _ensure_kb_algorithm(self, kb_id: str, algo_id: str): - now = now_ts() + now = datetime.now() with self._db_manager.get_session() as session: Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) row = session.query(Rel).filter(Rel.kb_id == kb_id).first() @@ -372,13 +194,15 @@ def _build_kb_data(kb_row, algo_row=None): 'doc_count': kb_row.doc_count, 'status': kb_row.status, 'owner_id': kb_row.owner_id, - 'meta': _from_json(kb_row.meta), + 'meta': from_json(kb_row.meta), 'algo_id': algo_row.algo_id if algo_row is not None else None, 'created_at': kb_row.created_at, 'updated_at': kb_row.updated_at, } def _validate_kb_algorithm(self, kb_id: str, algo_id: str): + if kb_id == '__default__': + self._ensure_default_kb() kb = self._get_kb(kb_id) if kb is None: raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) @@ -401,7 +225,7 @@ def _ensure_algorithm_exists(self, algo_id: str): raise DocServiceError('E_INVALID_PARAM', f'invalid algo_id: {algo_id}', {'algo_id': algo_id}) def _ensure_kb_document(self, kb_id: str, doc_id: str): - now = now_ts() + now = datetime.now() created = False with self._db_manager.get_session() as session: Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) @@ -438,7 +262,7 @@ def _load_idempotency_record(self, endpoint: str, idempotency_key: str): return _orm_to_dict(row) if row else None def _cleanup_idempotency_records(self, ttl_days: int = 7): - cutoff = now_ts() - timedelta(days=ttl_days) + cutoff = datetime.now() - timedelta(days=ttl_days) with self._db_manager.get_session() as session: Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) session.query(Record).filter(Record.updated_at < cutoff).delete() @@ -446,7 +270,7 @@ def _cleanup_idempotency_records(self, ttl_days: int = 7): def _claim_idempotency_key(self, endpoint: str, idempotency_key: str, req_hash: str): with self._db_manager.get_session() as session: Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) - now = now_ts() + now = datetime.now() row = Record( endpoint=endpoint, idempotency_key=idempotency_key, @@ -470,8 +294,8 @@ def _complete_idempotency_record(self, endpoint: str, idempotency_key: str, resp if row is None: return row.status = 'COMPLETED' - row.response_json = _stable_json(response) - row.updated_at = now_ts() + row.response_json = stable_json(response) + row.updated_at = datetime.now() session.add(row) def _drop_idempotency_claim(self, endpoint: str, idempotency_key: str): @@ -487,7 +311,7 @@ def _drop_idempotency_claim(self, endpoint: str, idempotency_key: str): def run_idempotent(self, endpoint: str, idempotency_key: Optional[str], payload: Any, handler): if not idempotency_key: return handler() - req_hash = _hash_payload(payload) + req_hash = hash_payload(payload) try: self._claim_idempotency_key(endpoint, idempotency_key, req_hash) except IntegrityError: @@ -513,7 +337,7 @@ def run_idempotent(self, endpoint: str, idempotency_key: Optional[str], payload: def _record_callback(self, callback_id: str, task_id: str): with self._db_manager.get_session() as session: Record = self._db_manager.get_table_orm_class(CALLBACK_RECORDS_TABLE_INFO['name']) - 
session.add(Record(callback_id=callback_id, task_id=task_id, created_at=now_ts())) + session.add(Record(callback_id=callback_id, task_id=task_id, created_at=datetime.now())) try: session.flush() return True @@ -531,7 +355,7 @@ def _forget_callback_record(self, callback_id: str, task_id: str): def _create_task_record(self, task_id: str, task_type: TaskType, doc_id: str, kb_id: str, algo_id: str, status: DocStatus, message: Optional[Dict[str, Any]] = None): - now = now_ts() + now = datetime.now() with self._db_manager.get_session() as session: Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) session.add(Task( @@ -541,7 +365,7 @@ def _create_task_record(self, task_id: str, task_type: TaskType, doc_id: str, kb kb_id=kb_id, algo_id=algo_id, status=status.value, - message=_to_json(message), + message=to_json(message), error_code=None, error_msg=None, created_at=now, @@ -558,7 +382,7 @@ def _get_task_record(self, task_id: str): if row is None: return None task = _orm_to_dict(row) - task['message'] = _from_json(task.get('message')) + task['message'] = from_json(task.get('message')) return task def _update_task_record(self, task_id: str, **fields): @@ -569,7 +393,7 @@ def _update_task_record(self, task_id: str, **fields): return None for key, value in fields.items(): setattr(row, key, value) - row.updated_at = now_ts() + row.updated_at = datetime.now() session.add(row) return self._get_task_record(task_id) @@ -583,7 +407,7 @@ def _refresh_kb_doc_count(self, kb_id: str): kb_row.doc_count = session.query(Rel).filter(Rel.kb_id == kb_id).count() if kb_row.status == KBStatus.DELETING.value and kb_row.doc_count == 0: kb_row.status = KBStatus.DELETED.value - kb_row.updated_at = now_ts() + kb_row.updated_at = datetime.now() session.add(kb_row) def _list_kb_doc_ids(self, kb_id: str) -> List[str]: @@ -624,10 +448,10 @@ def _upsert_doc( upload_status: DocStatus = DocStatus.SUCCESS, allowed_path_doc_ids: Optional[Set[str]] = None, ): - now = now_ts() + now = datetime.now() file_type = os.path.splitext(path)[1].lstrip('.').lower() or None size_bytes = os.path.getsize(path) if os.path.exists(path) else None - content_hash = _sha256_file(path) if os.path.exists(path) else None + content_hash = sha256_file(path) if os.path.exists(path) else None allowed_path_doc_ids = allowed_path_doc_ids or set() try: with self._db_manager.get_session() as session: @@ -645,7 +469,7 @@ def _upsert_doc( doc_id=doc_id, filename=filename, path=path, - meta=_to_json(metadata), + meta=to_json(metadata), upload_status=upload_status.value, source_type=source_type.value, file_type=file_type, @@ -657,7 +481,7 @@ def _upsert_doc( else: row.filename = filename row.path = path - row.meta = _to_json(metadata) + row.meta = to_json(metadata) row.upload_status = upload_status.value row.source_type = source_type.value row.file_type = file_type @@ -687,7 +511,7 @@ def _set_doc_upload_status(self, doc_id: str, status: DocStatus): if row is None: return None row.upload_status = status.value - row.updated_at = now_ts() + row.updated_at = datetime.now() session.add(row) return self._get_doc(doc_id) @@ -743,7 +567,7 @@ def _mark_task_cleanup_policy(self, task_id: str, cleanup_policy: str): if message.get('cleanup_policy') == cleanup_policy: return message['cleanup_policy'] = cleanup_policy - self._update_task_record(task_id, message=_to_json(message)) + self._update_task_record(task_id, message=to_json(message)) def _finalize_kb_deletion_if_empty(self, kb_id: str) -> bool: with self._db_manager.get_session() as session: 
@@ -839,7 +663,7 @@ def _upsert_parse_snapshot( started_at: Optional[datetime] = None, finished_at: Optional[datetime] = None, ): - now = now_ts() + now = datetime.now() with self._db_manager.get_session() as session: State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) row = ( @@ -993,7 +817,7 @@ def _enqueue_task( current_task_id=task_id, idempotency_key=idempotency_key, priority=priority, - queued_at=now_ts(), + queued_at=datetime.now(), started_at=None, finished_at=None, error_code=None, @@ -1007,7 +831,7 @@ def _enqueue_task( parser_kb_id=parser_kb_id, transfer_params=transfer_params, ) except Exception as exc: - finished_at = now_ts() + finished_at = datetime.now() error_msg = str(exc) self._update_task_record( task_id, @@ -1082,7 +906,7 @@ def _prepare_reparse_items(self, request: ReparseRequest) -> List[Dict[str, Any] prepared_items.append({ 'doc_id': doc_id, 'file_path': doc.get('path'), - 'metadata': _from_json(doc.get('meta')), + 'metadata': from_json(doc.get('meta')), }) return prepared_items @@ -1126,7 +950,7 @@ def _prepare_metadata_patch_items(self, request: MetadataPatchRequest) -> List[D if doc is None or not self._has_kb_document(request.kb_id, item.doc_id): raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') self._assert_action_allowed(item.doc_id, request.kb_id, request.algo_id, 'metadata') - merged = _from_json(doc.get('meta')) + merged = from_json(doc.get('meta')) merged.update(item.patch) prepared_items.append({'doc_id': item.doc_id, 'metadata': merged, 'file_path': doc.get('path')}) return prepared_items @@ -1194,13 +1018,15 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An 'status': source_snapshot.get('status') if source_snapshot else None, }, ) - source_metadata = _from_json(doc.get('meta')) - target_metadata = _merge_transfer_metadata(source_metadata, item.target_metadata) - target_path = _resolve_transfer_target_path(doc.get('path'), item.target_filename, item.target_file_path) + source_metadata = from_json(doc.get('meta')) + target_metadata = merge_transfer_metadata(source_metadata, item.target_metadata) + target_path = resolve_transfer_target_path(doc.get('path'), item.target_filename, item.target_file_path) target_filename = os.path.basename(target_path) prepared_items.append({ 'doc_id': item.doc_id, 'target_doc_id': item.target_doc_id, + 'kb_id': item.kb_id, + 'algo_id': item.algo_id, 'source_kb_id': item.source_kb_id, 'source_algo_id': item.source_algo_id, 'target_kb_id': item.target_kb_id, @@ -1223,6 +1049,7 @@ def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, An def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: self._validate_kb_algorithm(request.kb_id, request.algo_id) prepared_items = self._prepare_upload_items(request) + source_type = request.source_type or SourceType.API items: List[Dict[str, Any]] = [] for item in prepared_items: doc_id = item['doc_id'] @@ -1233,7 +1060,7 @@ def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: filename=item['filename'], path=file_path, metadata=metadata, - source_type=request.source_type, + source_type=source_type, upload_status=DocStatus.SUCCESS, ) self._ensure_kb_document(request.kb_id, doc_id) @@ -1272,7 +1099,7 @@ def add_files(self, request: AddRequest) -> List[Dict[str, Any]]: items=request.items, kb_id=request.kb_id, algo_id=request.algo_id, - source_type=request.source_type, + source_type=request.source_type or SourceType.EXTERNAL, idempotency_key=request.idempotency_key, 
)) @@ -1339,7 +1166,7 @@ def delete(self, request: DeleteRequest) -> List[Dict[str, Any]]: row = session.query(Doc).filter(Doc.doc_id == doc_id).first() if self._doc_relation_count(doc_id) <= 1: row.upload_status = DocStatus.DELETING.value - row.updated_at = now_ts() + row.updated_at = datetime.now() session.add(row) task_id, snapshot = self._enqueue_task( @@ -1408,6 +1235,8 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: 'doc_id': item['doc_id'], 'target_doc_id': item['target_doc_id'], 'task_id': task_id, + 'kb_id': item['kb_id'], + 'algo_id': item['algo_id'], 'source_kb_id': item['source_kb_id'], 'target_kb_id': item['target_kb_id'], 'source_algo_id': item['source_algo_id'], @@ -1431,8 +1260,8 @@ def patch_metadata(self, request: MetadataPatchRequest): with self._db_manager.get_session() as session: Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) row = session.query(Doc).filter(Doc.doc_id == item['doc_id']).first() - row.meta = _to_json(item['metadata']) - row.updated_at = now_ts() + row.meta = to_json(item['metadata']) + row.updated_at = datetime.now() session.add(row) task_id, _ = self._enqueue_task( @@ -1513,7 +1342,7 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 self._update_task_record( callback.task_id, status=DocStatus.WORKING.value, - started_at=now_ts(), + started_at=datetime.now(), finished_at=None, error_code=None, error_msg=None, @@ -1528,7 +1357,7 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 snapshot, task_type=task_type, current_task_id=callback.task_id, - started_at=now_ts(), + started_at=datetime.now(), finished_at=None, error_code=None, error_msg=None, @@ -1567,7 +1396,7 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 status=final_status.value, error_code=callback.error_code, error_msg=callback.error_msg, - finished_at=now_ts(), + finished_at=datetime.now(), ) self._upsert_parse_snapshot( @@ -1582,7 +1411,7 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 error_code=callback.error_code, error_msg=callback.error_msg, failed_stage=failed_stage, - finished_at=now_ts(), + finished_at=datetime.now(), ), ) @@ -1635,7 +1464,7 @@ def list_docs( ) if status and (snapshot is None or snapshot.get('status') not in status): continue - doc['metadata'] = _from_json(doc.get('meta')) + doc['metadata'] = from_json(doc.get('meta')) items.append({'doc': doc, 'relation': relation, 'snapshot': snapshot}) total = len(items) page_items = items[(page - 1) * page_size:page * page_size] @@ -1646,7 +1475,7 @@ def get_doc_detail(self, doc_id: str): if not doc: raise DocServiceError('E_NOT_FOUND', f'doc not found: {doc_id}', {'doc_id': doc_id}) - doc['metadata'] = _from_json(doc.get('meta')) + doc['metadata'] = from_json(doc.get('meta')) snapshots = self.list_docs(page=1, page_size=2000)['items'] matched_items = [] for item in snapshots: @@ -1671,8 +1500,8 @@ def list_tasks(self, status: Optional[List[str]], page: int, page_size: int): if callable(parser_list_tasks): try: return parser_list_tasks(status=status, page=page, page_size=page_size) - except Exception: - pass + except Exception as exc: + LOG.warning(f'[DocService] Fallback to local task list: {exc}') page = max(page, 1) page_size = max(page_size, 1) with self._db_manager.get_session() as session: @@ -1690,7 +1519,7 @@ def list_tasks(self, status: Optional[List[str]], page: int, page_size: int): items = [] for row in rows: task = _orm_to_dict(row) - task['message'] = 
_from_json(task.get('message')) + task['message'] = from_json(task.get('message')) items.append(task) return BaseResponse(code=200, msg='success', data={ 'items': items, @@ -1704,8 +1533,8 @@ def get_task(self, task_id: str): if callable(parser_get_task): try: return parser_get_task(task_id) - except Exception: - pass + except Exception as exc: + LOG.warning(f'[DocService] Fallback to local task query: {exc}') task = self._get_task_record(task_id) if task is None: return BaseResponse(code=404, msg='task not found', data=None) @@ -1830,12 +1659,17 @@ def list_chunks( return data def health(self): + parser_ok = False + try: + parser_ok = self._parser_client.health().code == 200 + except Exception as exc: + LOG.warning(f'[DocService] Parser health check failed: {exc}') return { 'status': 'ok', 'version': 'v1', 'deps': { 'sql': True, - 'parser': bool(getattr(self._parser_client, '_parser_url', None)), + 'parser': parser_ok, }, } @@ -1983,7 +1817,7 @@ def delete_kb(self, kb_id: str): kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() if kb_row is not None: kb_row.status = new_status - kb_row.updated_at = now_ts() + kb_row.updated_at = datetime.now() session.add(kb_row) else: self._finalize_kb_deletion_if_empty(kb_id) diff --git a/lazyllm/tools/rag/doc_service/doc_server.openapi.json b/lazyllm/tools/rag/doc_service/doc_server.openapi.json new file mode 100644 index 000000000..0577f9e01 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/doc_server.openapi.json @@ -0,0 +1,2078 @@ +{ + "components": { + "schemas": { + "AddFileItem": { + "properties": { + "doc_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Doc Id" + }, + "file_path": { + "title": "File Path", + "type": "string" + }, + "metadata": { + "additionalProperties": true, + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "file_path" + ], + "title": "AddFileItem", + "type": "object" + }, + "AlgorithmInfoRequest": { + "properties": { + "algo_id": { + "title": "Algo Id", + "type": "string" + } + }, + "required": [ + "algo_id" + ], + "title": "AlgorithmInfoRequest", + "type": "object" + }, + "Body_upload_v1_docs_upload_post": { + "properties": { + "algo_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Algo Id" + }, + "doc_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Doc Id" + }, + "files": { + "items": { + "contentMediaType": "application/octet-stream", + "type": "string" + }, + "title": "Files", + "type": "array" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Kb Id" + }, + "source_type": { + "anyOf": [ + { + "$ref": "#/components/schemas/SourceType" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "files" + ], + "title": "Body_upload_v1_docs_upload_post", + "type": "object" + }, + "CallbackEventType": { + "enum": [ + "START", + "FINISH" + ], + "title": "CallbackEventType", + "type": "string" + }, + "DeleteRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "doc_ids": { + "items": { + "type": "string" + }, + "title": "Doc Ids", + "type": "array" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_id": { + "default": "__default__", + "title": "Kb 
Id", + "type": "string" + } + }, + "required": [ + "doc_ids" + ], + "title": "DeleteRequest", + "type": "object" + }, + "DocItemsRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "items": { + "items": { + "$ref": "#/components/schemas/AddFileItem" + }, + "title": "Items", + "type": "array" + }, + "kb_id": { + "default": "__default__", + "title": "Kb Id", + "type": "string" + }, + "source_type": { + "anyOf": [ + { + "$ref": "#/components/schemas/SourceType" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "items" + ], + "title": "DocItemsRequest", + "type": "object" + }, + "DocStatus": { + "enum": [ + "WAITING", + "WORKING", + "SUCCESS", + "FAILED", + "CANCELED", + "DELETING", + "DELETED" + ], + "title": "DocStatus", + "type": "string" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "KbBatchQueryRequest": { + "properties": { + "kb_ids": { + "items": { + "type": "string" + }, + "title": "Kb Ids", + "type": "array" + } + }, + "required": [ + "kb_ids" + ], + "title": "KbBatchQueryRequest", + "type": "object" + }, + "KbCreateRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "display_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Display Name" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_id": { + "title": "Kb Id", + "type": "string" + }, + "meta": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Meta" + }, + "owner_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Owner Id" + } + }, + "required": [ + "kb_id" + ], + "title": "KbCreateRequest", + "type": "object" + }, + "KbDeleteBatchRequest": { + "properties": { + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_ids": { + "items": { + "type": "string" + }, + "title": "Kb Ids", + "type": "array" + } + }, + "required": [ + "kb_ids" + ], + "title": "KbDeleteBatchRequest", + "type": "object" + }, + "KbUpdateRequest": { + "properties": { + "algo_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Algo Id" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "display_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Display Name" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "meta": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Meta" + }, + "owner_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Owner Id" + } + }, + "title": "KbUpdateRequest", + "type": "object" + }, + "MetadataPatchItem": { + 
"properties": { + "doc_id": { + "title": "Doc Id", + "type": "string" + }, + "patch": { + "additionalProperties": true, + "title": "Patch", + "type": "object" + } + }, + "required": [ + "doc_id" + ], + "title": "MetadataPatchItem", + "type": "object" + }, + "MetadataPatchRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "items": { + "items": { + "$ref": "#/components/schemas/MetadataPatchItem" + }, + "title": "Items", + "type": "array" + }, + "kb_id": { + "default": "__default__", + "title": "Kb Id", + "type": "string" + } + }, + "required": [ + "items" + ], + "title": "MetadataPatchRequest", + "type": "object" + }, + "ReparseRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "doc_ids": { + "items": { + "type": "string" + }, + "title": "Doc Ids", + "type": "array" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_id": { + "default": "__default__", + "title": "Kb Id", + "type": "string" + } + }, + "required": [ + "doc_ids" + ], + "title": "ReparseRequest", + "type": "object" + }, + "SourceType": { + "enum": [ + "API", + "SCAN", + "TEMP", + "EXTERNAL" + ], + "title": "SourceType", + "type": "string" + }, + "TaskBatchRequest": { + "properties": { + "task_ids": { + "items": { + "type": "string" + }, + "title": "Task Ids", + "type": "array" + } + }, + "required": [ + "task_ids" + ], + "title": "TaskBatchRequest", + "type": "object" + }, + "TaskCallbackPayload": { + "properties": { + "algo_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Algo Id" + }, + "callback_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Callback Id" + }, + "doc_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Doc Id" + }, + "error_code": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Code" + }, + "error_msg": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Msg" + }, + "event_type": { + "anyOf": [ + { + "$ref": "#/components/schemas/CallbackEventType" + }, + { + "type": "null" + } + ] + }, + "kb_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Kb Id" + }, + "payload": { + "additionalProperties": true, + "title": "Payload", + "type": "object" + }, + "status": { + "anyOf": [ + { + "$ref": "#/components/schemas/DocStatus" + }, + { + "type": "null" + } + ] + }, + "task_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Task Id" + }, + "task_status": { + "anyOf": [ + { + "$ref": "#/components/schemas/DocStatus" + }, + { + "type": "null" + } + ] + }, + "task_type": { + "anyOf": [ + { + "$ref": "#/components/schemas/TaskType" + }, + { + "type": "null" + } + ] + } + }, + "title": "TaskCallbackPayload", + "type": "object" + }, + "TaskCancelRequest": { + "properties": { + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "task_id": { + "title": "Task Id", + "type": "string" + } + }, + "required": [ + "task_id" + ], + "title": "TaskCancelRequest", + "type": "object" + }, + "TaskInfoRequest": { + "properties": { + 
"task_id": { + "title": "Task Id", + "type": "string" + } + }, + "required": [ + "task_id" + ], + "title": "TaskInfoRequest", + "type": "object" + }, + "TaskType": { + "enum": [ + "DOC_ADD", + "DOC_DELETE", + "DOC_UPDATE_META", + "DOC_REPARSE", + "DOC_TRANSFER" + ], + "title": "TaskType", + "type": "string" + }, + "TransferItem": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "doc_id": { + "title": "Doc Id", + "type": "string" + }, + "kb_id": { + "default": "__default__", + "title": "Kb Id", + "type": "string" + }, + "mode": { + "default": "copy", + "title": "Mode", + "type": "string" + }, + "target_algo_id": { + "title": "Target Algo Id", + "type": "string" + }, + "target_doc_id": { + "title": "Target Doc Id", + "type": "string" + }, + "target_file_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Target File Path" + }, + "target_filename": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Target Filename" + }, + "target_kb_id": { + "title": "Target Kb Id", + "type": "string" + }, + "target_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Target Metadata" + } + }, + "required": [ + "doc_id", + "target_doc_id", + "target_kb_id", + "target_algo_id" + ], + "title": "TransferItem", + "type": "object" + }, + "TransferRequest": { + "properties": { + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "items": { + "items": { + "$ref": "#/components/schemas/TransferItem" + }, + "title": "Items", + "type": "array" + } + }, + "required": [ + "items" + ], + "title": "TransferRequest", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + } + } + }, + "info": { + "description": "OpenAPI schema generated from current DocServer routes.", + "title": "LazyLLM DocService API", + "version": "1.0.0" + }, + "openapi": "3.1.0", + "paths": { + "/v1/algo/list": { + "get": { + "operationId": "list_algo_v1_algo_list_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + } + }, + "summary": "List Algo" + } + }, + "/v1/algo/{algo_id}/groups": { + "get": { + "operationId": "get_algo_groups_v1_algo__algo_id__groups_get", + "parameters": [ + { + "in": "path", + "name": "algo_id", + "required": true, + "schema": { + "title": "Algo Id", + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Algo Groups" + } + }, + "/v1/algorithms": { + "get": { + "operationId": "list_algorithms_v1_algorithms_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + 
}, + "description": "Successful Response" + } + }, + "summary": "List Algorithms" + } + }, + "/v1/algorithms/info": { + "post": { + "operationId": "get_algorithm_info_v1_algorithms_info_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AlgorithmInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Algorithm Info" + } + }, + "/v1/chunks": { + "get": { + "operationId": "list_chunks_v1_chunks_get", + "parameters": [ + { + "in": "query", + "name": "kb_id", + "required": true, + "schema": { + "title": "Kb Id", + "type": "string" + } + }, + { + "in": "query", + "name": "doc_id", + "required": true, + "schema": { + "title": "Doc Id", + "type": "string" + } + }, + { + "in": "query", + "name": "group", + "required": true, + "schema": { + "title": "Group", + "type": "string" + } + }, + { + "in": "query", + "name": "algo_id", + "required": false, + "schema": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + } + }, + { + "in": "query", + "name": "page", + "required": false, + "schema": { + "default": 1, + "title": "Page", + "type": "integer" + } + }, + { + "in": "query", + "name": "page_size", + "required": false, + "schema": { + "default": 20, + "title": "Page Size", + "type": "integer" + } + }, + { + "in": "query", + "name": "offset", + "required": false, + "schema": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Offset" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "List Chunks" + } + }, + "/v1/docs": { + "get": { + "operationId": "list_docs_v1_docs_get", + "parameters": [ + { + "in": "query", + "name": "kb_id", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Kb Id" + } + }, + { + "in": "query", + "name": "algo_id", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Algo Id" + } + }, + { + "in": "query", + "name": "keyword", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Keyword" + } + }, + { + "in": "query", + "name": "include_deleted_or_canceled", + "required": false, + "schema": { + "default": true, + "title": "Include Deleted Or Canceled", + "type": "boolean" + } + }, + { + "in": "query", + "name": "page", + "required": false, + "schema": { + "default": 1, + "title": "Page", + "type": "integer" + } + }, + { + "in": "query", + "name": "page_size", + "required": false, + "schema": { + "default": 20, + "title": "Page Size", + "type": "integer" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Status" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + 
"description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "List Docs" + } + }, + "/v1/docs/add": { + "post": { + "operationId": "add_v1_docs_add_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DocItemsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Add" + } + }, + "/v1/docs/delete": { + "post": { + "operationId": "delete_v1_docs_delete_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeleteRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Delete" + } + }, + "/v1/docs/metadata/patch": { + "post": { + "operationId": "patch_metadata_v1_docs_metadata_patch_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MetadataPatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Patch Metadata" + } + }, + "/v1/docs/reparse": { + "post": { + "operationId": "reparse_v1_docs_reparse_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ReparseRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Reparse" + } + }, + "/v1/docs/transfer": { + "post": { + "operationId": "transfer_v1_docs_transfer_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TransferRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Transfer" + } + }, + "/v1/docs/upload": { + "post": { + "operationId": "upload_v1_docs_upload_post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upload_v1_docs_upload_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + 
"description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Upload" + } + }, + "/v1/docs/{doc_id}": { + "get": { + "operationId": "get_doc_v1_docs__doc_id__get", + "parameters": [ + { + "in": "path", + "name": "doc_id", + "required": true, + "schema": { + "title": "Doc Id", + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Doc" + } + }, + "/v1/health": { + "get": { + "operationId": "health_v1_health_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + } + }, + "summary": "Health" + } + }, + "/v1/internal/callbacks/tasks": { + "post": { + "operationId": "task_callback_http_v1_internal_callbacks_tasks_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TaskCallbackPayload" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Task Callback Http" + } + }, + "/v1/internal/parser-url": { + "get": { + "operationId": "get_parser_url_v1_internal_parser_url_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + } + }, + "summary": "Get Parser Url" + } + }, + "/v1/kbs": { + "delete": { + "operationId": "delete_kbs_v1_kbs_delete", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/KbDeleteBatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Delete Kbs" + }, + "get": { + "operationId": "list_kbs_v1_kbs_get", + "parameters": [ + { + "in": "query", + "name": "page", + "required": false, + "schema": { + "default": 1, + "title": "Page", + "type": "integer" + } + }, + { + "in": "query", + "name": "page_size", + "required": false, + "schema": { + "default": 20, + "title": "Page Size", + "type": "integer" + } + }, + { + "in": "query", + "name": "keyword", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Keyword" + } + }, + { + "in": "query", + "name": "owner_id", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Owner Id" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Status" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": 
{ + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "List Kbs" + }, + "post": { + "operationId": "create_kb_v1_kbs_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/KbCreateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Create Kb" + } + }, + "/v1/kbs/batch": { + "post": { + "operationId": "batch_get_kbs_v1_kbs_batch_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/KbBatchQueryRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Batch Get Kbs" + } + }, + "/v1/kbs/{kb_id}": { + "delete": { + "operationId": "delete_kb_v1_kbs__kb_id__delete", + "parameters": [ + { + "in": "path", + "name": "kb_id", + "required": true, + "schema": { + "title": "Kb Id", + "type": "string" + } + }, + { + "in": "query", + "name": "idempotency_key", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Delete Kb" + }, + "get": { + "operationId": "get_kb_v1_kbs__kb_id__get", + "parameters": [ + { + "in": "path", + "name": "kb_id", + "required": true, + "schema": { + "title": "Kb Id", + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Kb" + } + }, + "/v1/kbs/{kb_id}/update": { + "post": { + "operationId": "update_kb_v1_kbs__kb_id__update_post", + "parameters": [ + { + "in": "path", + "name": "kb_id", + "required": true, + "schema": { + "title": "Kb Id", + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/KbUpdateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Update Kb" + } + }, + "/v1/tasks": { + "get": { + "operationId": "list_tasks_v1_tasks_get", + 
"parameters": [ + { + "in": "query", + "name": "page", + "required": false, + "schema": { + "default": 1, + "title": "Page", + "type": "integer" + } + }, + { + "in": "query", + "name": "page_size", + "required": false, + "schema": { + "default": 20, + "title": "Page Size", + "type": "integer" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Status" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "List Tasks" + } + }, + "/v1/tasks/batch": { + "post": { + "operationId": "get_tasks_batch_v1_tasks_batch_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TaskBatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Tasks Batch" + } + }, + "/v1/tasks/cancel": { + "post": { + "operationId": "cancel_task_v1_tasks_cancel_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TaskCancelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Cancel Task" + } + }, + "/v1/tasks/info": { + "post": { + "operationId": "get_task_info_v1_tasks_info_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TaskInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Task Info" + } + }, + "/v1/tasks/{task_id}": { + "get": { + "operationId": "get_task_v1_tasks__task_id__get", + "parameters": [ + { + "in": "path", + "name": "task_id", + "required": true, + "schema": { + "title": "Task Id", + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Task" + } + } + } +} \ No newline at end of file diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py index 40098851b..8d786f479 100644 --- a/lazyllm/tools/rag/doc_service/doc_server.py +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -4,7 +4,7 @@ import json import os import 
traceback -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set import requests @@ -13,13 +13,32 @@ from ..utils import BaseResponse, _get_default_db_config, ensure_call_endpoint from .base import ( - AddRequest, DeleteRequest, DocServiceError, KbBatchQueryRequest, KbCreateRequest, KbUpdateRequest, - MetadataPatchRequest, ReparseRequest, + AddFileItem, + AddRequest, + AlgorithmInfoRequest, + CallbackEventType, + DeleteRequest, + DocServiceError, + DocStatus, + KbBatchQueryRequest, + KbCreateRequest, + KbDeleteBatchRequest, + KbUpdateRequest, + MetadataPatchRequest, + ReparseRequest, + SourceType, + TaskBatchRequest, + TaskCallbackPayload, + TaskCallbackRequest, + TaskCancelRequest, + TaskInfoRequest, + TransferRequest, + UploadRequest, ) -from .base import CallbackEventType, DocStatus, SourceType, TaskCallbackRequest -from .base import TransferRequest -from .base import UploadRequest, AddFileItem from .doc_manager import DocManager +from .utils import sha256_file + +DEFAULT_OPENAPI_OUTPUT_PATH = os.path.join(os.path.dirname(__file__), 'doc_server.openapi.json') class DocServer(ModuleBase): @@ -84,6 +103,7 @@ def _run(self, func, *args, success_msg='success', **kwargs): @staticmethod def _build_upload_payload(request: UploadRequest, file_identities: Optional[List[Dict[str, Any]]] = None): + source_type = request.source_type or SourceType.API items = file_identities if items is None: items = [] @@ -91,10 +111,8 @@ def _build_upload_payload(request: UploadRequest, file_identities: Optional[List content_hash = None size_bytes = None if os.path.exists(item.file_path): - with open(item.file_path, 'rb') as fh: - content = fh.read() - content_hash = hashlib.sha256(content).hexdigest() - size_bytes = len(content) + content_hash = sha256_file(item.file_path) + size_bytes = os.path.getsize(item.file_path) items.append({ 'filename': os.path.basename(item.file_path), 'content_hash': content_hash, @@ -104,7 +122,7 @@ def _build_upload_payload(request: UploadRequest, file_identities: Optional[List return { 'kb_id': request.kb_id, 'algo_id': request.algo_id, - 'source_type': request.source_type.value, + 'source_type': source_type.value, 'idempotency_key': request.idempotency_key, 'items': items, } @@ -132,6 +150,34 @@ def _gen_unique_upload_path(self, filename: str, reserved_paths: Optional[set] = digest = hashlib.sha256(safe_name.encode()).hexdigest()[:8] return os.path.join(self._storage_dir, f'{prefix}-{digest}{suffix}') + @staticmethod + async def _save_upload_file(upload_file: 'fastapi.UploadFile', file_path: str): + with open(file_path, 'wb') as fh: + while True: + chunk = await upload_file.read(1024 * 1024) + if not chunk: + break + fh.write(chunk) + await upload_file.close() + + async def _persist_uploads(self, files: List['fastapi.UploadFile']): + saved_paths = [] + file_identities = [] + reserved_paths: Set[str] = set() + for upload_file in files: + filename = getattr(upload_file, 'filename', None) or 'upload.bin' + file_path = self._gen_unique_upload_path(filename, reserved_paths) + await self._save_upload_file(upload_file, file_path) + reserved_paths.add(file_path) + saved_paths.append(file_path) + file_identities.append({ + 'filename': os.path.basename(file_path), + 'content_hash': sha256_file(file_path), + 'size_bytes': os.path.getsize(file_path), + 'doc_id': None, + }) + return saved_paths, file_identities + def _run_upload(self, request: UploadRequest, payload: Optional[Dict[str, Any]] = None): idem_payload = payload or 
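A minimal standalone sketch of the chunked save-and-hash pattern behind _save_upload_file and sha256_file above: stream the payload in 1 MiB chunks and update the digest while writing, so large uploads never have to be buffered whole in memory. Function names and paths here are illustrative only.

import hashlib

CHUNK_SIZE = 1024 * 1024  # 1 MiB, matching the read size used above

def save_and_hash(src_path: str, dst_path: str):
    # Copy src to dst in chunks while computing the SHA-256 of the content.
    digest = hashlib.sha256()
    size = 0
    with open(src_path, 'rb') as src, open(dst_path, 'wb') as dst:
        for chunk in iter(lambda: src.read(CHUNK_SIZE), b''):
            digest.update(chunk)
            dst.write(chunk)
            size += len(chunk)
    return digest.hexdigest(), size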
self._build_upload_payload(request) return self._run(lambda: self._manager.run_idempotent( @@ -205,60 +251,46 @@ def upload_request(self, request: UploadRequest): @app.post('/v1/docs/upload') async def upload( self, - request: 'fastapi.Request', - kb_id: str = '__default__', - algo_id: str = '__default__', - source_type: SourceType = SourceType.API, - doc_id: Optional[str] = None, - idempotency_key: Optional[str] = None, + files: List['fastapi.UploadFile'] = fastapi.File(...), # noqa: B008 + kb_id: Optional[str] = fastapi.Form(None), # noqa: B008 + algo_id: Optional[str] = fastapi.Form(None), # noqa: B008 + source_type: Optional[SourceType] = fastapi.Form(None), # noqa: B008 + doc_id: Optional[str] = fastapi.Form(None), # noqa: B008 + idempotency_key: Optional[str] = fastapi.Form(None), # noqa: B008 + kb_id_query: Optional[str] = fastapi.Query(None, alias='kb_id', include_in_schema=False), # noqa: B008 + algo_id_query: Optional[str] = fastapi.Query(None, alias='algo_id', include_in_schema=False), # noqa: B008 + source_type_query: Optional[SourceType] = fastapi.Query( # noqa: B008 + None, alias='source_type', include_in_schema=False + ), + doc_id_query: Optional[str] = fastapi.Query(None, alias='doc_id', include_in_schema=False), # noqa: B008 + idempotency_key_query: Optional[str] = fastapi.Query( # noqa: B008 + None, + alias='idempotency_key', + include_in_schema=False, + ), ): self._lazy_init() - form = await request.form() - files = form.getlist('files') if not files: raise fastapi.HTTPException(status_code=400, detail='files is required') - buffered_files = [] - file_identities = [] - for idx, file in enumerate(files): - filename = getattr(file, 'filename', None) or str(getattr(file, 'name', 'upload.bin')) - content = await file.read() if hasattr(file, 'read') else file.file.read() - buffered_files.append({'filename': filename, 'content': content}) - file_identities.append({ - 'filename': filename, - 'content_hash': hashlib.sha256(content).hexdigest(), - 'size_bytes': len(content), - 'doc_id': doc_id if idx == 0 else None, - }) - - def _handle_upload(): - saved_paths = [] - reserved_paths = set() - for item in buffered_files: - file_path = self._gen_unique_upload_path(item['filename'], reserved_paths) - with open(file_path, 'wb') as fh: - fh.write(item['content']) - saved_paths.append(file_path) - reserved_paths.add(file_path) - upload_request = UploadRequest( - items=[AddFileItem(file_path=path, doc_id=(doc_id if idx == 0 else None)) - for idx, path in enumerate(saved_paths)], - kb_id=kb_id, - algo_id=algo_id, - source_type=source_type, - idempotency_key=idempotency_key, - ) - return {'items': self._manager.upload(upload_request)} - - payload = { - 'kb_id': kb_id, - 'algo_id': algo_id, - 'source_type': source_type.value, - 'idempotency_key': idempotency_key, - 'items': file_identities, - } - return self._run(lambda: self._manager.run_idempotent( - '/v1/docs/upload', idempotency_key, payload, _handle_upload - )) + kb_id = kb_id or kb_id_query or '__default__' + algo_id = algo_id or algo_id_query or '__default__' + source_type = source_type or source_type_query or SourceType.API + doc_id = doc_id or doc_id_query + idempotency_key = idempotency_key or idempotency_key_query + saved_paths, file_identities = await self._persist_uploads(files) + upload_request = UploadRequest( + items=[ + AddFileItem(file_path=path, doc_id=(doc_id if idx == 0 else None)) + for idx, path in enumerate(saved_paths) + ], + kb_id=kb_id, + algo_id=algo_id, + source_type=source_type, + 
idempotency_key=idempotency_key, + ) + if file_identities: + file_identities[0]['doc_id'] = doc_id + return self._run_upload(upload_request, self._build_upload_payload(upload_request, file_identities)) @app.post('/v1/docs/add') def add(self, request: AddRequest): @@ -359,15 +391,12 @@ def cancel_task_by_id(self, task_id: str): return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) @app.post('/v1/tasks/cancel') - async def cancel_task(self, request: 'fastapi.Request'): - payload = await request.json() - task_id = payload.get('task_id') - if not task_id: - raise fastapi.HTTPException(status_code=400, detail='task_id is required') - idempotency_key = payload.get('idempotency_key') + def cancel_task(self, request: TaskCancelRequest): + self._lazy_init() + payload = request.model_dump(mode='json') def _cancel(): - resp = self._manager.cancel_task(task_id) + resp = self._manager.cancel_task(request.task_id) if resp.code == 404: raise DocServiceError('E_NOT_FOUND', resp.msg, resp.data) if resp.code == 409: @@ -376,7 +405,7 @@ def _cancel(): raise DocServiceError('E_INVALID_PARAM', resp.msg, resp.data) return resp.data return self._run(lambda: self._manager.run_idempotent( - '/v1/tasks/cancel', idempotency_key, payload, _cancel + '/v1/tasks/cancel', request.idempotency_key, payload, _cancel )) def task_callback(self, callback: Any): @@ -384,10 +413,8 @@ def task_callback(self, callback: Any): return self._run(lambda: self._manager.on_task_callback(self._normalize_task_callback(callback))) @app.post('/v1/internal/callbacks/tasks') - async def task_callback_http(self, request: 'fastapi.Request'): - self._lazy_init() - payload = await request.json() - return self.task_callback(payload) + def task_callback_http(self, request: TaskCallbackPayload): + return self.task_callback(request.model_dump(mode='json', exclude_none=True)) @app.get('/v1/algo/list') def list_algo(self): @@ -409,14 +436,9 @@ def list_algorithms_impl(self): return self._run(lambda: self._manager.list_algorithms_compat()) @app.post('/v1/algorithms/info') - async def get_algorithm_info(self, request: 'fastapi.Request'): + def get_algorithm_info(self, request: AlgorithmInfoRequest): self._lazy_init() - payload = await request.json() - algo_id = payload.get('algo_id') - if not algo_id: - return self._response(data={'biz_code': 'E_INVALID_PARAM'}, code=400, - msg='algo_id is required', status_code=400) - return self._run(lambda: self._manager.get_algorithm_info(algo_id)) + return self._run(lambda: self._manager.get_algorithm_info(request.algo_id)) def get_algorithm_info_impl(self, algo_id: str): self._lazy_init() @@ -445,25 +467,18 @@ def list_chunks( )) @app.post('/v1/tasks/batch') - async def get_tasks_batch(self, request: 'fastapi.Request'): + def get_tasks_batch(self, request: TaskBatchRequest): self._lazy_init() - payload = await request.json() - task_ids = payload.get('task_ids') or [] - return self._run(lambda: self._manager.get_tasks_batch(task_ids)) + return self._run(lambda: self._manager.get_tasks_batch(request.task_ids)) def get_tasks_batch_impl(self, task_ids: List[str]): self._lazy_init() return self._run(lambda: self._manager.get_tasks_batch(task_ids)) @app.post('/v1/tasks/info') - async def get_task_info(self, request: 'fastapi.Request'): + def get_task_info(self, request: TaskInfoRequest): self._lazy_init() - payload = await request.json() - task_id = payload.get('task_id') - if not task_id: - return self._response(data={'biz_code': 'E_INVALID_PARAM'}, code=400, - msg='task_id is 
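Assuming a locally running DocServer (the base URL below is a placeholder), the reworked routes above can be exercised with plain HTTP: the upload route now takes standard multipart form fields, and /v1/tasks/cancel validates a typed JSON body.

import requests

BASE = 'http://127.0.0.1:8000'  # placeholder; use the actual DocServer address

# Multipart upload: one or more 'files' parts plus optional form fields.
with open('report.pdf', 'rb') as fh:
    resp = requests.post(
        f'{BASE}/v1/docs/upload',
        files=[('files', ('report.pdf', fh, 'application/pdf'))],
        data={'kb_id': 'kb-demo', 'algo_id': '__default__'},
        timeout=30,
    )
print(resp.json())

# Task cancellation is now a TaskCancelRequest JSON body instead of a raw payload.
resp = requests.post(f'{BASE}/v1/tasks/cancel',
                     json={'task_id': 'task-1', 'idempotency_key': 'idem-1'},
                     timeout=30)
print(resp.json())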
required', status_code=400) - resp = self._manager.get_task(task_id) + resp = self._manager.get_task(request.task_id) return self._response( data=self._format_task_response_data(resp.data), code=resp.code, @@ -567,13 +582,11 @@ def delete_kb(self, kb_id: str, idempotency_key: Optional[str] = None): )) @app.delete('/v1/kbs') - async def delete_kbs(self, request: 'fastapi.Request'): + def delete_kbs(self, request: KbDeleteBatchRequest): self._lazy_init() - payload = await request.json() - kb_ids = payload.get('kb_ids') or [] - idempotency_key = payload.get('idempotency_key') + payload = request.model_dump(mode='json') return self._run(lambda: self._manager.run_idempotent( - '/v1/kbs:delete', idempotency_key, payload, lambda: self._manager.delete_kbs(kb_ids) + '/v1/kbs:delete', request.idempotency_key, payload, lambda: self._manager.delete_kbs(request.kb_ids) )) def delete_kbs_impl(self, kb_ids: List[str], idempotency_key: Optional[str] = None): @@ -657,6 +670,11 @@ def build_openapi_app(cls, title: str = 'LazyLLM DocService API', version: str = parser_url='http://127.0.0.1:9966', ) cls._register_openapi_routes(openapi_app, impl) + for route in openapi_app.routes: + body_field = getattr(route, 'body_field', None) + annotation = getattr(getattr(body_field, 'field_info', None), 'annotation', None) + if hasattr(annotation, 'model_rebuild'): + annotation.model_rebuild(force=True, _types_namespace=route.endpoint.__globals__) return openapi_app @classmethod @@ -664,9 +682,16 @@ def build_openapi_schema(cls, title: str = 'LazyLLM DocService API', version: st return cls.build_openapi_app(title=title, version=version).openapi() @classmethod - def export_openapi(cls, output_path: str, title: str = 'LazyLLM DocService API', version: str = '1.0.0'): + def export_openapi( + cls, + output_path: str = DEFAULT_OPENAPI_OUTPUT_PATH, + title: str = 'LazyLLM DocService API', + version: str = '1.0.0', + ): schema = cls.build_openapi_schema(title=title, version=version) - os.makedirs(os.path.dirname(output_path), exist_ok=True) + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) with open(output_path, 'w', encoding='utf-8') as fh: json.dump(schema, fh, ensure_ascii=False, indent=2, sort_keys=True) return output_path diff --git a/lazyllm/tools/rag/doc_service/parser_client.py b/lazyllm/tools/rag/doc_service/parser_client.py new file mode 100644 index 000000000..1173c208b --- /dev/null +++ b/lazyllm/tools/rag/doc_service/parser_client.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import requests + +from ..parsing_service.base import ( + AddDocRequest as ParsingAddDocRequest, + CancelTaskRequest as ParsingCancelTaskRequest, + DeleteDocRequest as ParsingDeleteDocRequest, + FileInfo as ParsingFileInfo, + TransferParams as ParsingTransferParams, + UpdateMetaRequest as ParsingUpdateMetaRequest, +) +from ..utils import BaseResponse +from .utils import normalize_api_base_url + + +class ParserClient: + def __init__(self, parser_url: str): + self._parser_url = normalize_api_base_url(parser_url) + + def _request(self, method: str, path: str, **kwargs): + response = requests.request(method, f'{self._parser_url}{path}', timeout=8, **kwargs) + if response.status_code >= 400: + raise RuntimeError(f'parser http error: {response.status_code} {response.text}') + return response.json() + + def _get_with_fallback(self, paths: List[str], params: Optional[Dict[str, Any]] = None): + last_error = None + for path in paths: + try: 
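With the export helper changes above, regenerating the checked-in spec is a one-liner; a hedged usage sketch, assuming the module is importable from its file location:

from lazyllm.tools.rag.doc_service.doc_server import DocServer

# Writes the schema to DEFAULT_OPENAPI_OUTPUT_PATH (doc_server.openapi.json
# next to doc_server.py) unless an explicit output_path is passed.
path = DocServer.export_openapi()
print('OpenAPI schema written to', path)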
+ return self._request('GET', path, params=params) + except RuntimeError as exc: + last_error = exc + if '404' not in str(exc): + raise + if last_error is not None: + raise last_error + raise RuntimeError('parser http error: no path provided') + + def health(self): + return BaseResponse.model_validate(self._request('GET', '/health')) + + def add_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, file_path: str, + metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, + callback_url: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None): + req = ParsingAddDocRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + callback_url=callback_url, + feedback_url=callback_url, + file_infos=[ParsingFileInfo( + file_path=file_path, + doc_id=doc_id, + metadata=metadata or {}, + reparse_group=reparse_group, + transfer_params=( + ParsingTransferParams.model_validate(transfer_params) + if transfer_params is not None else None + ), + )], + ) + return BaseResponse.model_validate(self._request('POST', '/doc/add', json=req.model_dump(mode='json'))) + + def update_meta(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, + metadata: Optional[Dict[str, Any]] = None, file_path: Optional[str] = None, + callback_url: Optional[str] = None): + req = ParsingUpdateMetaRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + callback_url=callback_url, + feedback_url=callback_url, + file_infos=[ParsingFileInfo(file_path=file_path, doc_id=doc_id, metadata=metadata or {})], + ) + return BaseResponse.model_validate( + self._request('POST', '/doc/meta/update', json=req.model_dump(mode='json')) + ) + + def delete_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, + callback_url: Optional[str] = None): + req = ParsingDeleteDocRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + doc_ids=[doc_id], + callback_url=callback_url, + feedback_url=callback_url, + ) + return BaseResponse.model_validate(self._request('DELETE', '/doc/delete', json=req.model_dump(mode='json'))) + + def cancel_task(self, task_id: str): + req = ParsingCancelTaskRequest(task_id=task_id) + return BaseResponse.model_validate(self._request('POST', '/doc/cancel', json=req.model_dump(mode='json'))) + + def list_algorithms(self): + return BaseResponse.model_validate(self._get_with_fallback(['/v1/algo/list', '/algo/list'])) + + def get_algorithm_groups(self, algo_id: str): + try: + data = self._get_with_fallback([f'/v1/algo/{algo_id}/groups', f'/algo/{algo_id}/group/info']) + return BaseResponse.model_validate(data) + except RuntimeError as exc: + if '404' in str(exc): + return BaseResponse(code=404, msg='algo not found', data=None) + raise + + def list_doc_chunks(self, algo_id: str, kb_id: str, doc_id: str, group: str, offset: int, page_size: int): + data = self._request('GET', '/doc/chunks', params={ + 'algo_id': algo_id, + 'kb_id': kb_id, + 'doc_id': doc_id, + 'group': group, + 'offset': offset, + 'page_size': page_size, + }) + return BaseResponse.model_validate(data) diff --git a/lazyllm/tools/rag/doc_service/utils.py b/lazyllm/tools/rag/doc_service/utils.py new file mode 100644 index 000000000..9a15d35cb --- /dev/null +++ b/lazyllm/tools/rag/doc_service/utils.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import hashlib +import json +import os +from typing import Any, Dict, Optional + +from lazyllm import LOG +from ..utils import gen_docid + + +def to_json(data: Optional[Dict[str, Any]]) -> str: + return json.dumps(data or {}, ensure_ascii=False) + + 
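A hedged usage sketch for the ParserClient added above, assuming a parsing service is reachable at the placeholder URL and that the import path mirrors the new file location:

from lazyllm.tools.rag.doc_service.parser_client import ParserClient

client = ParserClient('http://127.0.0.1:9966')  # placeholder URL
print(client.health())             # GET /health, wrapped into a BaseResponse
print(client.list_algorithms())    # tries /v1/algo/list, then falls back to /algo/list
resp = client.add_doc(task_id='task-1', algo_id='__default__', kb_id='kb-demo',
                      doc_id='doc-1', file_path='/tmp/report.pdf',
                      metadata={'source': 'demo'})
print(resp.code, resp.msg)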
+def from_json(raw: Optional[str]) -> Dict[str, Any]: + if not raw: + return {} + try: + return json.loads(raw) + except Exception: + LOG.warning('[DocService] Failed to decode json payload') + return {} + + +def gen_doc_id(file_path: str, doc_id: Optional[str] = None) -> str: + return doc_id or gen_docid(file_path) + + +def stable_json(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, sort_keys=True, default=str) + + +def hash_payload(data: Any) -> str: + return hashlib.sha256(stable_json(data).encode()).hexdigest() + + +def sha256_file(file_path: str) -> str: + digest = hashlib.sha256() + with open(file_path, 'rb') as fh: + for chunk in iter(lambda: fh.read(1024 * 1024), b''): + digest.update(chunk) + return digest.hexdigest() + + +def merge_transfer_metadata( + source_metadata: Dict[str, Any], target_metadata: Optional[Dict[str, Any]] +) -> Dict[str, Any]: + metadata = dict(source_metadata or {}) + if target_metadata: + metadata.update(target_metadata) + return metadata + + +def resolve_transfer_target_path( + source_path: str, target_filename: Optional[str], target_file_path: Optional[str] +) -> str: + if target_file_path: + return target_file_path + if target_filename: + base_dir = os.path.dirname(source_path) if source_path else '' + return os.path.join(base_dir, target_filename) if base_dir else target_filename + return source_path + + +def normalize_api_base_url(url: str) -> str: + url = url.rstrip('/') + if url.endswith('/_call') or url.endswith('/generate'): + return url.rsplit('/', 1)[0] + return url diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index 690786098..19b4a65d3 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -207,9 +207,12 @@ def _resolve_add_doc_task_type(request: AddDocRequest) -> str: # noqa: C901 } # Finished task queue table +# NOTE: callback-related columns were appended after the initial queue schema. Existing deployments may still +# use an older table layout, but queue initialization already auto-adds any missing nullable columns in place +# via ``_SQLBasedQueue._ensure_columns_exist()``, so startup remains backward compatible without extra migration code. FINISHED_TASK_QUEUE_TABLE_INFO = { 'name': 'lazyllm_finished_task_queue', - 'comment': 'Finished task queue table', + 'comment': 'Finished task queue table; legacy tables are extended in place with new columns at startup', 'columns': [ {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, 'comment': 'Auto increment ID'}, diff --git a/lazyllm/tools/rag/parsing_service/queue.py b/lazyllm/tools/rag/parsing_service/queue.py index d0a04e790..5293c6323 100644 --- a/lazyllm/tools/rag/parsing_service/queue.py +++ b/lazyllm/tools/rag/parsing_service/queue.py @@ -38,6 +38,8 @@ def __init__(self, table_name: str, columns: List[Dict[str, Any]], db_config: Di raise def _ensure_columns_exist(self): + # Keep queue tables backward compatible by adding newly introduced columns + # to existing tables instead of requiring a manual migration step. 
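A minimal sketch of the in-place, backward-compatible column addition described in the comment above (not the actual queue code): inspect the live table and append any columns that newer code expects but an older deployment's table is missing.

import sqlalchemy

def ensure_columns(engine, table_name, columns):
    # columns: list of dicts with at least a 'name' and a SQL 'data_type'.
    inspector = sqlalchemy.inspect(engine)
    existing = {col['name'] for col in inspector.get_columns(table_name)}
    with engine.begin() as conn:
        for col in columns:
            if col['name'] in existing:
                continue
            # Nullable columns can be appended without backfilling a default.
            conn.execute(sqlalchemy.text(
                f'ALTER TABLE {table_name} ADD COLUMN {col["name"]} {col["data_type"]}'
            ))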
inspector = sqlalchemy.inspect(self._sql_manager.engine) existing_columns = {column['name'] for column in inspector.get_columns(self._table_name)} missing_columns = [column for column in self._columns if column['name'] not in existing_columns] diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index 43ecd853e..ec5a1b97d 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -1,13 +1,13 @@ +import inspect import json import random -import inspect import threading import time import traceback from datetime import datetime, timedelta from email.utils import parsedate_to_datetime +from typing import Any, Callable, Dict, List, Optional, Tuple from uuid import NAMESPACE_URL, uuid5 -from typing import Any, Callable, Dict, Optional, Tuple, List from lazyllm import ( LOG, ModuleBase, ServerModule, UrlModule, FastapiApp as app, @@ -38,7 +38,7 @@ class DocumentProcessor(ModuleBase): - class _Impl(): + class _Impl: def __init__(self, db_config: Optional[Dict[str, Any]] = None, num_workers: int = 1, post_func: Optional[Callable] = None, path_prefix: Optional[str] = None, callback_url: Optional[str] = None, @@ -496,13 +496,7 @@ def _format_chunk_item(segment: Dict[str, Any]) -> Dict[str, Any]: } def _list_doc_chunks_data( - self, - algo_id: str, - kb_id: str, - doc_id: str, - group: str, - offset: int = 0, - limit: int = 20, + self, algo_id: str, kb_id: str, doc_id: str, group: str, offset: int = 0, limit: int = 20 ) -> Dict[str, Any]: algorithm = self._get_algo(algo_id) if algorithm is None: diff --git a/tests/basic_tests/RAG/test_doc_service_doc_server.py b/tests/basic_tests/RAG/test_doc_service_doc_server.py index 5297bb463..ee10740e8 100644 --- a/tests/basic_tests/RAG/test_doc_service_doc_server.py +++ b/tests/basic_tests/RAG/test_doc_service_doc_server.py @@ -5,6 +5,7 @@ import pytest import requests +from pydantic import ValidationError from lazyllm.thirdparty import fastapi @@ -17,43 +18,30 @@ KbUpdateRequest, SourceType, TaskCallbackRequest, + TaskCancelRequest, UploadRequest, ) from lazyllm.tools.rag.utils import BaseResponse -class _JsonRequest: - def __init__(self, payload): - self._payload = payload - - async def json(self): - return self._payload - - -class _FormData: - def __init__(self, files): - self._files = files - - def getlist(self, name): - assert name == 'files' - return self._files - - -class _FormRequest: - def __init__(self, files): - self._form = _FormData(files) - - async def form(self): - return self._form - - class _UploadFile: def __init__(self, filename: str, content: bytes): self.filename = filename self._content = content + self._offset = 0 + + async def read(self, size: int = -1): + if self._offset >= len(self._content): + return b'' + if size is None or size < 0: + size = len(self._content) - self._offset + start = self._offset + end = min(len(self._content), start + size) + self._offset = end + return self._content[start:end] - async def read(self): - return self._content + async def close(self): + return None class _FakeManager: @@ -167,11 +155,9 @@ def _raise(*args, **kwargs): assert server.parser_url is None -def test_cancel_task_http_requires_task_id(server_impl): - with pytest.raises(fastapi.HTTPException) as exc: - asyncio.run(server_impl.cancel_task(_JsonRequest({}))) - assert exc.value.status_code == 400 - assert exc.value.detail == 'task_id is required' +def test_task_cancel_request_requires_task_id(): + with pytest.raises(ValidationError): + 
TaskCancelRequest.model_validate({}) def test_cancel_task_http_maps_conflict(server_impl): @@ -181,7 +167,7 @@ def test_cancel_task_http_maps_conflict(server_impl): data={'task_id': 'task-1', 'cancel_status': False, 'status': 'WORKING'}, ) - response = asyncio.run(server_impl.cancel_task(_JsonRequest({'task_id': 'task-1', 'idempotency_key': 'idem'}))) + response = server_impl.cancel_task(TaskCancelRequest(task_id='task-1', idempotency_key='idem')) body = _decode_response(response) assert server_impl._manager.run_calls[0]['endpoint'] == '/v1/tasks/cancel' @@ -243,7 +229,7 @@ def test_upload_http_saves_unique_files_and_only_first_doc_id(server_impl): ] response = asyncio.run(server_impl.upload( - _FormRequest(files), + files=files, kb_id='kb-upload', algo_id='__default__', source_type=SourceType.API, From baefa882fbde090cfe5e43414d90346c4a4a4fee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 8 Apr 2026 14:26:57 +0800 Subject: [PATCH 42/46] Tighten doc manager formatting --- lazyllm/tools/rag/doc_service/doc_manager.py | 35 +++++++++++--------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index 05f9603b1..8ba175d52 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -15,27 +15,27 @@ from ..utils import BaseResponse, _get_default_db_config, _orm_to_dict from ...sql import SqlManager from .base import ( - AddRequest, - CallbackEventType, CALLBACK_RECORDS_TABLE_INFO, + DocServiceError, + DocStatus, + KBStatus, + SourceType, + TaskType, + AddRequest, DeleteRequest, + MetadataPatchRequest, + ReparseRequest, + TaskCallbackRequest, + TransferRequest, + UploadRequest, DOC_SERVICE_TASKS_TABLE_INFO, - DocServiceError, DOCUMENTS_TABLE_INFO, IDEMPOTENCY_RECORDS_TABLE_INFO, KB_ALGORITHM_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, - KBStatus, KBS_TABLE_INFO, - MetadataPatchRequest, PARSE_STATE_TABLE_INFO, - ReparseRequest, - SourceType, - TaskCallbackRequest, - TaskType, - TransferRequest, - UploadRequest, - DocStatus, + CallbackEventType, ) from .parser_client import ParserClient from .utils import ( @@ -63,10 +63,13 @@ def __init__( self._db_config = db_config or _get_default_db_config('doc_service') self._db_manager = SqlManager( **self._db_config, - tables_info_dict={'tables': [DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, - KB_ALGORITHM_TABLE_INFO, PARSE_STATE_TABLE_INFO, - IDEMPOTENCY_RECORDS_TABLE_INFO, CALLBACK_RECORDS_TABLE_INFO, - DOC_SERVICE_TASKS_TABLE_INFO]}, + tables_info_dict={ + 'tables': [ + DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, KB_ALGORITHM_TABLE_INFO, + PARSE_STATE_TABLE_INFO, IDEMPOTENCY_RECORDS_TABLE_INFO, CALLBACK_RECORDS_TABLE_INFO, + DOC_SERVICE_TASKS_TABLE_INFO, + ] + }, ) self._ensure_indexes() self._parser_client = ParserClient(parser_url=parser_url) From 50dad377a89b0a875b25fb568872243c2b2a222f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 8 Apr 2026 16:29:21 +0800 Subject: [PATCH 43/46] Finish remaining doc_service review follow-ups --- lazyllm/components/deploy/relay/server.py | 27 +- lazyllm/tools/rag/doc_impl.py | 18 +- lazyllm/tools/rag/doc_service/base.py | 13 +- lazyllm/tools/rag/doc_service/doc_manager.py | 58 +- lazyllm/tools/rag/doc_service/doc_server.py | 42 +- lazyllm/tools/rag/document.py | 118 ++- lazyllm/tools/rag/parsing_service/server.py | 17 +- lazyllm/tools/rag/parsing_service/worker.py | 
16 +- lazyllm/tools/rag/utils.py | 4 +- tests/basic_tests/RAG/test_doc_manager.py | 294 ++++++ tests/basic_tests/RAG/test_doc_processor.py | 46 +- .../RAG/test_doc_service_doc_manager.py | 47 +- .../basic_tests/RAG/test_doc_service_mock.py | 994 ------------------ tests/basic_tests/RAG/test_document.py | 270 +---- 14 files changed, 545 insertions(+), 1419 deletions(-) create mode 100644 tests/basic_tests/RAG/test_doc_manager.py delete mode 100644 tests/basic_tests/RAG/test_doc_service_mock.py diff --git a/lazyllm/components/deploy/relay/server.py b/lazyllm/components/deploy/relay/server.py index b98457ac7..51e713562 100644 --- a/lazyllm/components/deploy/relay/server.py +++ b/lazyllm/components/deploy/relay/server.py @@ -37,8 +37,7 @@ def _inject_pythonpath(argv): from lazyllm import FastapiApp, globals, kwargs, load_obj, package # noqa: E402 from lazyllm.common import _register_trim_module, _trim_traceback # noqa: E402 -from fastapi import FastAPI, Request # noqa: E402 -from fastapi.responses import Response, StreamingResponse # noqa: E402 +from lazyllm.thirdparty import fastapi # noqa: E402 import requests # noqa: E402 # TODO(sunxiaoye): delete in the future @@ -80,7 +79,7 @@ def _inject_pythonpath(argv): _err_msg += (f' defined at `{load_obj(args.defined_pos)}`' if args.defined_pos else '') + ':\n' -app = FastAPI() +app = fastapi.FastAPI() FastapiApp.update() async def async_wrapper(func, *args, **kwargs): @@ -95,30 +94,30 @@ def impl(func, sid, global_data, *args, **kw): def security_check(f: Callable): @functools.wraps(f) - async def wrapper(request: Request): + async def wrapper(request: fastapi.Request): if args.security_key and args.security_key != request.headers.get('Security-Key'): - return Response(content='Authentication failed', status_code=401) + return fastapi.responses.Response(content='Authentication failed', status_code=401) return (await f(request)) if inspect.iscoroutinefunction(f) else f(request) return wrapper @app.post('/_call') @security_check -async def lazyllm_call(request: Request): +async def lazyllm_call(request: fastapi.Request): try: fname, args, kwargs = await request.json() args, kwargs = load_obj(args), load_obj(kwargs) r = await async_wrapper(getattr(func, fname), *args, **kwargs) - return Response(content=codecs.encode(pickle.dumps(r), 'base64')) + return fastapi.responses.Response(content=codecs.encode(pickle.dumps(r), 'base64')) except requests.RequestException as e: - return Response(content=f'{str(e)}', status_code=500) + return fastapi.responses.Response(content=f'{str(e)}', status_code=500) except Exception: exc_type, exc_value, exc_tb = sys.exc_info() formatted = ''.join(traceback.format_exception(exc_type, exc_value, _trim_traceback(exc_tb))) - return Response(content=f'{_err_msg}\n{formatted}', status_code=500) + return fastapi.responses.Response(content=f'{_err_msg}\n{formatted}', status_code=500) @app.post('/generate') @security_check -async def generate(request: Request): # noqa C901 +async def generate(request: fastapi.Request): # noqa C901 try: input, kw = (await request.json()), {} try: @@ -157,7 +156,7 @@ def impl(o): def generate_stream(): for o in output: yield impl(o) - return StreamingResponse(generate_stream(), media_type='text_plain') + return fastapi.responses.StreamingResponse(generate_stream(), media_type='text_plain') elif args.after_function: assert (callable(after_func)), 'after_func must be callable' r = inspect.getfullargspec(after_func) @@ -169,13 +168,13 @@ def generate_stream(): after_func(output, **{r.kwonlyargs[0]: 
origin}) elif len(new_args) == 2: output = after_func(output, origin) - return Response(content=impl(output)) + return fastapi.responses.Response(content=impl(output)) except requests.RequestException as e: - return Response(content=f'{str(e)}', status_code=500) + return fastapi.responses.Response(content=f'{str(e)}', status_code=500) except Exception: exc_type, exc_value, exc_tb = sys.exc_info() formatted = ''.join(traceback.format_exception(exc_type, exc_value, _trim_traceback(exc_tb))) - return Response(content=f'{_err_msg}\n{formatted}', status_code=500) + return fastapi.responses.Response(content=f'{_err_msg}\n{formatted}', status_code=500) finally: globals.clear() diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index c4e0dffa5..b7301f443 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ b/lazyllm/tools/rag/doc_impl.py @@ -14,7 +14,7 @@ from .store.document_store import _DocumentStore from .doc_node import DocNode from .data_loaders import DirectoryReader -from .utils import gen_docid, is_sparse, _get_default_db_config +from .utils import RAG_DEFAULT_GROUP_NAME, gen_docid, is_sparse, _get_default_db_config from .global_metadata import GlobalMetadataDesc, RAG_DOC_ID, RAG_DOC_PATH, RAG_KB_ID from .data_type import DataType from .parsing_service import _Processor, DocumentProcessor @@ -68,7 +68,6 @@ class DocImpl: _builtin_node_groups: Dict[str, Dict] = {} _global_node_groups: Dict[str, Dict] = {} _registered_file_reader: Dict[str, Callable] = {} - DEFAULT_GROUP_NAME = '__default__' def __init__(self, embed: Dict[str, Callable], dataset_path: Optional[str] = None, enable_path_monitoring: bool = False, @@ -80,7 +79,7 @@ def __init__(self, embed: Dict[str, Callable], dataset_path: Optional[str] = Non schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): super().__init__() self._local_file_reader: Dict[str, Callable] = {} - self._kb_group_name = kb_group_name or self.DEFAULT_GROUP_NAME + self._kb_group_name = kb_group_name or RAG_DEFAULT_GROUP_NAME self._dataset_path = dataset_path self._enable_path_monitoring = bool(enable_path_monitoring and dataset_path and doc_files is None) self._doc_files = doc_files @@ -546,16 +545,9 @@ def _get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = sort_by_number: bool = False) -> Union[List[DocNode], Tuple[List[DocNode], int]]: self._lazy_init() return self._store.get_nodes( - uids=uids, - doc_ids=doc_ids, - group=group, - kb_id=kb_id, - numbers=numbers, - limit=limit, - offset=offset, - return_total=return_total, - sort_by_number=sort_by_number, - display=True, + uids=uids, doc_ids=doc_ids, group=group, kb_id=kb_id, numbers=numbers, + limit=limit, offset=offset, return_total=return_total, + sort_by_number=sort_by_number, display=True, ) def _get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5), diff --git a/lazyllm/tools/rag/doc_service/base.py b/lazyllm/tools/rag/doc_service/base.py index d3fd5412c..c3e80fc0e 100644 --- a/lazyllm/tools/rag/doc_service/base.py +++ b/lazyllm/tools/rag/doc_service/base.py @@ -230,8 +230,8 @@ class TaskCallbackPayload(BaseModel): payload: Dict[str, Any] = Field(default_factory=dict) -class KbCreateRequest(BaseModel): - kb_id: str +class KbRequest(BaseModel): + kb_id: Optional[str] = None display_name: Optional[str] = None description: Optional[str] = None owner_id: Optional[str] = None @@ -240,13 +240,8 @@ class KbCreateRequest(BaseModel): idempotency_key: Optional[str] = None -class KbUpdateRequest(BaseModel): - display_name: Optional[str] = 
None - description: Optional[str] = None - owner_id: Optional[str] = None - meta: Optional[Dict[str, Any]] = None - algo_id: Optional[str] = None - idempotency_key: Optional[str] = None +KbCreateRequest = KbRequest +KbUpdateRequest = KbRequest class KbBatchQueryRequest(BaseModel): diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index 8ba175d52..dbf02447c 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -15,38 +15,16 @@ from ..utils import BaseResponse, _get_default_db_config, _orm_to_dict from ...sql import SqlManager from .base import ( - CALLBACK_RECORDS_TABLE_INFO, - DocServiceError, - DocStatus, - KBStatus, - SourceType, - TaskType, - AddRequest, - DeleteRequest, - MetadataPatchRequest, - ReparseRequest, - TaskCallbackRequest, - TransferRequest, + AddRequest, CALLBACK_RECORDS_TABLE_INFO, CallbackEventType, DOC_SERVICE_TASKS_TABLE_INFO, + DOCUMENTS_TABLE_INFO, DeleteRequest, DocServiceError, DocStatus, IDEMPOTENCY_RECORDS_TABLE_INFO, + KB_ALGORITHM_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KBStatus, MetadataPatchRequest, + PARSE_STATE_TABLE_INFO, ReparseRequest, SourceType, TaskCallbackRequest, TaskType, TransferRequest, UploadRequest, - DOC_SERVICE_TASKS_TABLE_INFO, - DOCUMENTS_TABLE_INFO, - IDEMPOTENCY_RECORDS_TABLE_INFO, - KB_ALGORITHM_TABLE_INFO, - KB_DOCUMENTS_TABLE_INFO, - KBS_TABLE_INFO, - PARSE_STATE_TABLE_INFO, - CallbackEventType, ) from .parser_client import ParserClient from .utils import ( - from_json, - gen_doc_id, - hash_payload, - merge_transfer_metadata, - resolve_transfer_target_path, - sha256_file, - stable_json, - to_json, + from_json, gen_doc_id, hash_payload, merge_transfer_metadata, resolve_transfer_target_path, + sha256_file, stable_json, to_json, ) @@ -81,7 +59,8 @@ def set_callback_url(self, callback_url: str): def _ensure_indexes(self): stmts = [ - 'CREATE UNIQUE INDEX IF NOT EXISTS uq_docs_path ON lazyllm_documents(path)', + 'DROP INDEX IF EXISTS uq_docs_path', + 'CREATE INDEX IF NOT EXISTS idx_docs_path ON lazyllm_documents(path)', 'CREATE INDEX IF NOT EXISTS idx_documents_upload_status ON lazyllm_documents(upload_status)', 'CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON lazyllm_documents(updated_at)', 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_display_name ' @@ -133,17 +112,9 @@ def _ensure_kb(self, kb_id: str, display_name: Optional[str] = None, description Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) row = session.query(Kb).filter(Kb.kb_id == kb_id).first() if row is None: - row = Kb( - kb_id=kb_id, - display_name=display_name, - description=description, - doc_count=0, - status=KBStatus.ACTIVE.value, - owner_id=owner_id, - meta=to_json(meta), - created_at=now, - updated_at=now, - ) + row = Kb(kb_id=kb_id, display_name=display_name, description=description, doc_count=0, + status=KBStatus.ACTIVE.value, owner_id=owner_id, meta=to_json(meta), + created_at=now, updated_at=now) else: if update_fields is None: update_fields = set() @@ -168,10 +139,9 @@ def _ensure_kb_algorithm(self, kb_id: str, algo_id: str): if row is None: row = Rel(kb_id=kb_id, algo_id=algo_id, created_at=now, updated_at=now) elif row.algo_id != algo_id: - raise DocServiceError( - 'E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm {row.algo_id}', - {'kb_id': kb_id, 'bound_algo_id': row.algo_id, 'requested_algo_id': algo_id} - ) + raise DocServiceError('E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm 
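The shared KbRequest above still supports partial updates because pydantic records which fields the caller actually supplied; a small self-contained illustration (the model name below is a stand-in, not the real class):

from typing import Optional
from pydantic import BaseModel

class KbPatch(BaseModel):  # stand-in for the shared KbRequest model
    kb_id: Optional[str] = None
    display_name: Optional[str] = None
    description: Optional[str] = None

req = KbPatch.model_validate({'display_name': 'demo'})
print(req.model_fields_set)                             # {'display_name'}
print(req.model_dump(mode='json', exclude_unset=True))  # {'display_name': 'demo'}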
{row.algo_id}', + {'kb_id': kb_id, 'bound_algo_id': row.algo_id, + 'requested_algo_id': algo_id}) else: row.updated_at = now session.add(row) diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py index 8d786f479..1c309de42 100644 --- a/lazyllm/tools/rag/doc_service/doc_server.py +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -52,6 +52,8 @@ def __init__( parser_url: Optional[str] = None, callback_url: Optional[str] = None, ): + if not parser_url: + raise ValueError('parser_url is required; doc_service no longer starts a mock parsing server') self._storage_dir = storage_dir self._db_config = db_config self._parser_db_config = parser_db_config @@ -64,8 +66,6 @@ def __init__( @once_wrapper(reset_on_pickle=True) def _lazy_init(self): os.makedirs(self._storage_dir, exist_ok=True) - if not self._parser_url: - raise ValueError('parser_url is required; doc_service no longer starts a mock parsing server') self._manager = DocManager( db_config=self._db_config, parser_url=self._parser_url, @@ -131,7 +131,7 @@ def _build_upload_payload(request: UploadRequest, file_identities: Optional[List def _build_update_kb_payload(kb_id: str, request: KbUpdateRequest): payload = request.model_dump(mode='json', exclude_unset=True) payload['kb_id'] = kb_id - payload['explicit_fields'] = sorted(request.model_fields_set) + payload['explicit_fields'] = sorted(field for field in request.model_fields_set if field != 'kb_id') return payload def _gen_unique_upload_path(self, filename: str, reserved_paths: Optional[set] = None): @@ -257,26 +257,13 @@ async def upload( source_type: Optional[SourceType] = fastapi.Form(None), # noqa: B008 doc_id: Optional[str] = fastapi.Form(None), # noqa: B008 idempotency_key: Optional[str] = fastapi.Form(None), # noqa: B008 - kb_id_query: Optional[str] = fastapi.Query(None, alias='kb_id', include_in_schema=False), # noqa: B008 - algo_id_query: Optional[str] = fastapi.Query(None, alias='algo_id', include_in_schema=False), # noqa: B008 - source_type_query: Optional[SourceType] = fastapi.Query( # noqa: B008 - None, alias='source_type', include_in_schema=False - ), - doc_id_query: Optional[str] = fastapi.Query(None, alias='doc_id', include_in_schema=False), # noqa: B008 - idempotency_key_query: Optional[str] = fastapi.Query( # noqa: B008 - None, - alias='idempotency_key', - include_in_schema=False, - ), ): self._lazy_init() if not files: raise fastapi.HTTPException(status_code=400, detail='files is required') - kb_id = kb_id or kb_id_query or '__default__' - algo_id = algo_id or algo_id_query or '__default__' - source_type = source_type or source_type_query or SourceType.API - doc_id = doc_id or doc_id_query - idempotency_key = idempotency_key or idempotency_key_query + kb_id = kb_id or '__default__' + algo_id = algo_id or '__default__' + source_type = source_type or SourceType.API saved_paths, file_identities = await self._persist_uploads(files) upload_request = UploadRequest( items=[ @@ -457,13 +444,8 @@ def list_chunks( ): self._lazy_init() return self._run(lambda: self._manager.list_chunks( - kb_id=kb_id, - doc_id=doc_id, - group=group, - algo_id=algo_id, - page=page, - page_size=page_size, - offset=offset, + kb_id=kb_id, doc_id=doc_id, group=group, algo_id=algo_id, + page=page, page_size=page_size, offset=offset, )) @app.post('/v1/tasks/batch') @@ -535,6 +517,8 @@ def create_kb_by_id(self, kb_id: str, display_name: Optional[str] = None, descri @app.post('/v1/kbs') def create_kb(self, request: KbCreateRequest): self._lazy_init() + if not 
request.kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') payload = request.model_dump(mode='json') return self._run(lambda: self._manager.run_idempotent( '/v1/kbs', request.idempotency_key, payload, @@ -550,6 +534,12 @@ def create_kb(self, request: KbCreateRequest): def update_kb_by_id(self, kb_id: str, request: KbUpdateRequest): self._lazy_init() + if request.kb_id and request.kb_id != kb_id: + raise DocServiceError( + 'E_INVALID_PARAM', + f'kb_id mismatch: path={kb_id}, body={request.kb_id}', + {'kb_id': kb_id, 'request_kb_id': request.kb_id}, + ) payload = self._build_update_kb_payload(kb_id, request) return self._run(lambda: self._manager.run_idempotent( f'/v1/kbs/{kb_id}:patch', request.idempotency_key, payload, diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index f97e0a6d9..b8edd172d 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -17,7 +17,7 @@ from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY from .store.store_base import DEFAULT_KB_ID from .index_base import IndexBase -from .utils import ensure_call_endpoint +from .utils import RAG_DEFAULT_GROUP_NAME, ensure_call_endpoint from .global_metadata import GlobalMetadataDesc as DocField from .web import DocWebModule import copy @@ -44,18 +44,25 @@ def __instancecheck__(self, __instance): class Document(ModuleBase, BuiltinGroups, metaclass=_MetaDocument): class _Manager(ModuleBase): - DEFAULT_GROUP_NAME = '__default__' - - def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, Dict[str, Callable]]] = None, - manager: Union[bool, str, DocServer] = False, server: Union[bool, int] = False, - name: Optional[str] = None, - launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None, - doc_fields: Optional[Dict[str, DocField]] = None, cloud: bool = False, - doc_files: Optional[List[str]] = None, processor: Optional[DocumentProcessor] = None, - display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', - schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, - create_ui: bool = False, - enable_path_monitoring: Optional[bool] = None): + def __init__( # noqa: C901 + self, + dataset_path: Optional[str], + embed: Optional[Union[Callable, Dict[str, Callable]]] = None, + manager: Union[bool, str, DocServer] = False, + server: Union[bool, int] = False, + name: Optional[str] = None, + launcher: Optional[Launcher] = None, + store_conf: Optional[Dict] = None, + doc_fields: Optional[Dict[str, DocField]] = None, + cloud: bool = False, + doc_files: Optional[List[str]] = None, + processor: Optional[DocumentProcessor] = None, + display_name: Optional[str] = '', + description: Optional[str] = 'algorithm description', + schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, + create_ui: bool = False, + enable_path_monitoring: Optional[bool] = None, + ): super().__init__() self._origin_path, self._doc_files, self._cloud = dataset_path, doc_files, cloud @@ -70,6 +77,9 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, self._dataset_path = dataset_path self._embed = self._get_embeds(embed) self._processor = processor + self._create_ui = create_ui + self._spawn_doc_server = False + self._doc_processor_started = False compat_ui_manager = manager == 'ui' if compat_ui_manager: lazyllm.LOG.warning('`manager=\'ui\'` is deprecated, use `manager=True, create_ui=True` instead') @@ -81,14 +91,9 @@ def __init__(self, dataset_path: Optional[str], embed: 
Optional[Union[Callable, self._doc_impl_dataset_path = doc_impl_dataset_path self._doc_processor = None if spawn_doc_server: + self._spawn_doc_server = True self._doc_processor = DocumentProcessor(launcher=self._launcher, pythonpath=_LOCAL_PYTHONPATH) - self._doc_processor.start() - self._manager = DocServer( - launcher=self._launcher, - storage_dir=dataset_path, - parser_url=self._doc_processor.url, - pythonpath=_LOCAL_PYTHONPATH, - ) + self._submodules.remove(self._doc_processor) elif connect_doc_server: self._manager = manager parser_url = getattr(getattr(manager, '_raw_impl', None), '_parser_url', None) @@ -96,11 +101,11 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, parser_url = manager.parser_url if parser_url: self._doc_processor = DocumentProcessor(url=parser_url) - self._schema_extractor = self._register_submodules(schema_extractor) + self._schema_extractor = schema_extractor self._store_conf = store_conf self._display_name = display_name self._description = description - name = name or self.DEFAULT_GROUP_NAME + name = name or RAG_DEFAULT_GROUP_NAME if not display_name: display_name = name if enable_path_monitoring is None: enable_path_monitoring = False if (spawn_doc_server or connect_doc_server or processor) else True @@ -115,7 +120,7 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, store=store_conf, processor=doc_processor, algo_name=name, display_name=display_name, description=description, schema_extractor=schema_extractor)}) - if create_ui: + if create_ui and not self._spawn_doc_server: self.ensure_doc_web() if server: self._kbs = ServerModule(self._kbs, port=(None if isinstance(server, bool) else int(server))) self._global_metadata_desc = doc_fields @@ -138,6 +143,8 @@ def web_url(self): def ensure_doc_web(self): if hasattr(self, '_docweb'): return self._docweb + if self._spawn_doc_server and not hasattr(self, '_manager'): + raise ValueError('`create_ui=True` with `manager=True` requires `Document.start()` before using the UI') if not hasattr(self, '_manager') or not isinstance(self._manager, DocServer): raise ValueError( '`create_ui=True` requires an available DocServer. 
'
@@ -146,33 +153,50 @@ def ensure_doc_web(self):
             self._docweb = DocWebModule(doc_server=self._manager)
             return self._docweb
 
+        def _ensure_doc_processor_started(self):
+            if self._doc_processor and not self._doc_processor_started:
+                self._doc_processor.start()
+                self._doc_processor_started = True
+
+        def _ensure_managed_services_started(self):
+            if self._spawn_doc_server:
+                self._ensure_doc_processor_started()
+                if not hasattr(self, '_manager'):
+                    self._manager = DocServer(
+                        launcher=self._launcher,
+                        storage_dir=self._dataset_path,
+                        parser_url=self._doc_processor.url,
+                        pythonpath=_LOCAL_PYTHONPATH,
+                    )
+                    self._manager.start()
+                if self._create_ui and not hasattr(self, '_docweb'):
+                    self.ensure_doc_web()
+                    self._docweb.start()
+
+        def _get_deploy_tasks(self):
+            if self._spawn_doc_server and not hasattr(self, '_manager'):
+                return lazyllm.pipeline(self._ensure_managed_services_started)
+            return None
+
         def _get_embeds(self, embed):
             embeds = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {}
-            return self._register_submodules(embeds)
-
-        def _register_submodules(self, m):
-            if not m: return m
-            for embed in (m.values() if isinstance(m, dict) else m if isinstance(m, (tuple, list)) else [m]):
-                if isinstance(embed, ModuleBase): self._submodules.append(embed)
-            return m
+            for index, module in enumerate(embeds.values()):
+                if isinstance(module, ModuleBase):
+                    setattr(self, f'_embed_module_{index}', module)
+            return embeds
+
         def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None,
-                         store_conf: Optional[Dict] = None,
-                         embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
+                         store_conf: Optional[Dict] = None, embed: Optional[Union[Callable, Dict[str, Callable]]] = None,
                          schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None):
             embed = self._get_embeds(embed) if embed else self._embed
-            schema_extractor = self._register_submodules(schema_extractor) or self._schema_extractor
+            schema_extractor = schema_extractor or self._schema_extractor
+            if isinstance(schema_extractor, ModuleBase):
+                setattr(self, f'_schema_extractor_{name}', schema_extractor)
             impl = DocImpl(
-                dataset_path=self._doc_impl_dataset_path,
-                embed=embed,
-                kb_group_name=name,
-                enable_path_monitoring=self._enable_path_monitoring,
-                global_metadata_desc=doc_fields,
-                store=store_conf or self._store_conf,
-                processor=self._doc_processor or self._processor,
-                algo_name=name,
-                display_name=name,
-                description=self._description,
+                dataset_path=self._doc_impl_dataset_path, embed=embed, kb_group_name=name,
+                enable_path_monitoring=self._enable_path_monitoring, global_metadata_desc=doc_fields,
+                store=store_conf or self._store_conf, processor=self._doc_processor or self._processor,
+                algo_name=name, display_name=name, description=self._description,
                 schema_extractor=schema_extractor,
             )
             (self._kbs._impl._m if isinstance(self._kbs, ServerModule) else self._kbs)[name] = impl
@@ -231,7 +255,7 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal
         assert store_conf is None or store_conf['type'] == 'map', (
             'Only map store is supported for Document with temp-files')
 
-        name = name or Document._Manager.DEFAULT_GROUP_NAME
+        name = name or RAG_DEFAULT_GROUP_NAME
         if isinstance(manager, Document._Manager):
             assert not server, 'Server infomation is already set to by manager'
@@ -254,7 +278,6 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal
             if _is_local_map_store(store_conf):
                 raise ValueError('`manager=DocumentProcessor(...)` does not support pure local map store')
             processor, cloud = manager, True
-            processor.start()
             manager = False
         else:
            cloud, processor = False, None
@@ -454,6 +477,9 @@ def register_index(self, index_type: str, index_cls: IndexBase, *args, **kwargs)
     def _forward(self, func_name: str, *args, **kw):
         return self._manager(self._curr_group, func_name, *args, **kw)
 
+    def start(self):
+        return super().start()
+
     def find_parent(self, target) -> Callable:
         return functools.partial(self._forward, 'find_parent', group=target)
@@ -503,7 +529,7 @@ def __init__(self, url: str, name: str = None):
         super().__init__()
         self._missing_keys = set(dir(Document)) - set(dir(UrlDocument))
         self._manager = lazyllm.UrlModule(url=ensure_call_endpoint(url))
-        self._curr_group = name or Document._Manager.DEFAULT_GROUP_NAME
+        self._curr_group = name or RAG_DEFAULT_GROUP_NAME
 
     def _forward(self, func_name: str, *args, **kwargs):
         args = (self._curr_group, func_name, *args)
diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py
index ec5a1b97d..df2b7cb35 100644
--- a/lazyllm/tools/rag/parsing_service/server.py
+++ b/lazyllm/tools/rag/parsing_service/server.py
@@ -44,7 +44,8 @@ def __init__(self, db_config: Optional[Dict[str, Any]] = None, num_workers: int
                      callback_url: Optional[str] = None, lease_duration: float = 300.0,
                      lease_renew_interval: float = 60.0,
                      high_priority_task_types: Optional[List[str]] = None,
-                     high_priority_workers: int = 1):
+                     high_priority_workers: int = 1, callback_task_statuses: Optional[List[str]] = None,
+                     callback_task_types: Optional[List[str]] = None):
             self._db_config = db_config
             self._num_workers = num_workers
             self._post_func = post_func
@@ -61,6 +62,8 @@
                 else [TaskType.DOC_DELETE.value]
             )
             self._high_priority_workers = max(high_priority_workers, 0)
+            self._callback_task_statuses = callback_task_statuses
+            self._callback_task_types = callback_task_types
             self._callback_retry_attempts: Dict[int, int] = {}
 
             self._db_manager = None
@@ -112,6 +115,8 @@
                     lease_renew_interval=self._lease_renew_interval,
                     high_priority_task_types=high_priority_types,
                     high_priority_only=True,
+                    callback_task_statuses=self._callback_task_statuses,
+                    callback_task_types=self._callback_task_types,
                 )
                 self._high_priority_workers_module.start()
             if normal_workers > 0:
@@ -121,6 +126,8 @@
                     lease_duration=self._lease_duration,
                     lease_renew_interval=self._lease_renew_interval,
                     high_priority_task_types=high_priority_types,
+                    callback_task_statuses=self._callback_task_statuses,
+                    callback_task_types=self._callback_task_types,
                 )
                 self._workers.start()
             LOG.info('[DocumentProcessor] Lazy initialization completed!')
@@ -826,7 +833,9 @@ def __init__(self, port: int = None, url: str = None, num_workers: int = 1,
                  launcher: Optional[Launcher] = None, post_func: Optional[Callable] = None,
                  path_prefix: Optional[str] = None, callback_url: Optional[str] = None,
                  lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: Optional[List[str]] = None,
-                 high_priority_workers: int = 1, pythonpath: Optional[str] = None):
+                 high_priority_workers: int = 1, pythonpath: Optional[str] = None,
+                 callback_task_statuses: Optional[List[str]] = None,
+                 callback_task_types: Optional[List[str]] = None):
         super().__init__()
         self._raw_impl = None # save the reference of the original Impl object
         self._db_config = db_config if db_config else _get_default_db_config('doc_task_management')
@@ -841,7 +850,9 @@ def __init__(self, port: int = None, url: str = None, num_workers: int = 1,
                 lease_renew_interval=lease_renew_interval,
                 high_priority_task_types=high_priority_task_types,
                 high_priority_workers=high_priority_workers,
-                callback_url=callback_url
+                callback_url=callback_url,
+                callback_task_statuses=callback_task_statuses,
+                callback_task_types=callback_task_types,
             )
             self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher, pythonpath=pythonpath)
         else:
diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py
index 37a6ad38f..9216e551e 100644
--- a/lazyllm/tools/rag/parsing_service/worker.py
+++ b/lazyllm/tools/rag/parsing_service/worker.py
@@ -26,7 +26,8 @@ class DocumentProcessorWorker(ModuleBase):
     class _Impl():
         def __init__(self, db_config: dict = None, task_poller=None, lease_duration: float = 300.0,
                      lease_renew_interval: float = 60.0, high_priority_task_types: list[str] = None,
-                     high_priority_only: bool = False, poll_mode: str = 'thread'):
+                     high_priority_only: bool = False, poll_mode: str = 'thread',
+                     callback_task_statuses: list[str] = None, callback_task_types: list[str] = None):
             self._db_config = db_config if db_config else _get_default_db_config('doc_task_management')
             self._shutdown = False
             self._processors: dict[str, _Processor] = {} # algo_id -> _Processor
@@ -49,6 +50,10 @@ def __init__(self, db_config: dict = None, task_poller=None, lease_duration: flo
             self._lease_renew_interval = lease_renew_interval
             self._high_priority_task_types = set(high_priority_task_types or [])
             self._high_priority_only = high_priority_only
+            self._callback_task_statuses = {TaskStatus(status).value for status in callback_task_statuses} \
+                if callback_task_statuses else None
+            self._callback_task_types = {TaskType(task_type).value for task_type in callback_task_types} \
+                if callback_task_types else None
 
         @once_wrapper(reset_on_pickle=True)
         def _lazy_init(self):
@@ -587,6 +592,10 @@ def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: Task
                                    callback_url: str = None, task_context_json: str = None):
             try:
                 self._lazy_init()
+                if self._callback_task_statuses and task_status.value not in self._callback_task_statuses:
+                    return
+                if self._callback_task_types and task_type not in self._callback_task_types:
+                    return
                 self._finished_task_queue.enqueue(
                     task_id=task_id,
                     task_type=task_type,
@@ -716,7 +725,8 @@ def shutdown(self):
     def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = None, task_poller=None,
                  lease_duration: float = 300.0, lease_renew_interval: float = 60.0,
                  high_priority_task_types: list[str] = None, high_priority_only: bool = False,
-                 poll_mode: str = 'thread'):
+                 poll_mode: str = 'thread', callback_task_statuses: list[str] = None,
+                 callback_task_types: list[str] = None):
         super().__init__()
         self._db_config = db_config if db_config else _get_default_db_config('doc_task_management')
         self._num_workers = num_workers
@@ -729,6 +739,8 @@ def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = Non
             high_priority_task_types=high_priority_task_types,
             high_priority_only=high_priority_only,
             poll_mode=poll_mode,
+            callback_task_statuses=callback_task_statuses,
+            callback_task_types=callback_task_types,
         )
         self._worker_impl = ServerModule(worker_impl, port=self._port, num_replicas=self._num_workers)
         LOG.info(f'[DocumentProcessorWorker] Worker initialized with {num_workers} workers')
diff --git a/lazyllm/tools/rag/utils.py 
b/lazyllm/tools/rag/utils.py index 0a2716aee..3fc373991 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -45,6 +45,8 @@ config.add('default_dlmanager', str, 'sqlite', 'DEFAULT_DOCLIST_MANAGER', description='The default document list manager for RAG.') +RAG_DEFAULT_GROUP_NAME = '__default__' + def gen_docid(file_path: str) -> str: return hashlib.sha256(file_path.encode()).hexdigest() @@ -104,7 +106,7 @@ class DocPathParsingResult(BaseModel): @deprecated('Document(dataset_path=..., enable_path_monitoring=...)') class DocListManager(ABC): - DEFAULT_GROUP_NAME = '__default__' + DEFAULT_GROUP_NAME = RAG_DEFAULT_GROUP_NAME __pool__ = dict() class Status: diff --git a/tests/basic_tests/RAG/test_doc_manager.py b/tests/basic_tests/RAG/test_doc_manager.py new file mode 100644 index 000000000..ac44e0b03 --- /dev/null +++ b/tests/basic_tests/RAG/test_doc_manager.py @@ -0,0 +1,294 @@ +import pytest +import lazyllm +from lazyllm.tools.rag.utils import DocListManager +from lazyllm.tools.rag.doc_manager import DocManager +import shutil +import hashlib +import sqlite3 +import unittest +import requests +import io +import json +import time + + +@pytest.fixture(autouse=True) +def setup_tmpdir(request, tmpdir): + request.cls.tmpdir = tmpdir + + +def get_fid(path): + if isinstance(path, (tuple, list)): + return type(path)(get_fid(p) for p in path) + else: + return hashlib.sha256(f'{path}'.encode()).hexdigest() + + +@pytest.mark.usefixtures("setup_tmpdir") +class TestDocListManager(unittest.TestCase): + + def setUp(self): + self.test_dir = test_dir = self.tmpdir.mkdir("test_documents") + + test_file_1, test_file_2 = test_dir.join("test1.txt"), test_dir.join("test2.txt") + test_file_1.write("This is a test file 1.") + test_file_2.write("This is a test file 2.") + self.test_file_1, self.test_file_2 = str(test_file_1), str(test_file_2) + + self.manager = DocListManager(str(test_dir), "TestManager") + + def tearDown(self): + shutil.rmtree(str(self.test_dir)) + self.manager.release() + + def test_init_tables(self): + self.manager.init_tables() + assert self.manager.table_inited() is True + + def test_add_files(self): + self.manager.init_tables() + + self.manager.add_files([self.test_file_1, self.test_file_2]) + files_list = self.manager.list_files(details=True) + assert len(files_list) == 2 + assert any(self.test_file_1.endswith(row[1]) for row in files_list) + assert any(self.test_file_2.endswith(row[1]) for row in files_list) + + def test_list_kb_group_files(self): + self.manager.init_tables() + # wait for files to be added + time.sleep(15) + files_list = self.manager.list_kb_group_files(DocListManager.DEFAULT_GROUP_NAME, details=True) + assert len(files_list) == 2 + files_list = self.manager.list_kb_group_files('group1', details=True) + assert len(files_list) == 0 + + self.manager.add_files_to_kb_group(get_fid([self.test_file_1, self.test_file_2]), + DocListManager.DEFAULT_GROUP_NAME) + files_list = self.manager.list_kb_group_files(DocListManager.DEFAULT_GROUP_NAME, details=True) + assert len(files_list) == 2 + + self.manager.add_files_to_kb_group(get_fid([self.test_file_1, self.test_file_2]), 'group1') + files_list = self.manager.list_kb_group_files('group1', details=True) + assert len(files_list) == 2 + + def test_list_kb_groups(self): + self.manager.init_tables() + assert len(self.manager.list_all_kb_group()) == 1 + + self.manager.add_kb_group('group1') + self.manager.add_kb_group('group2') + r = self.manager.list_all_kb_group() + assert len(r) == 3 and self.manager.DEFAULT_GROUP_NAME 
in r and 'group2' in r + + def test_delete_files(self): + self.manager.init_tables() + + self.manager.add_files([self.test_file_1, self.test_file_2]) + self.manager.delete_files([hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest()]) + files_list = self.manager.list_files(details=True) + assert len(files_list) == 2 + files_list = self.manager.list_files(details=True, exclude_status=DocListManager.Status.deleting) + assert len(files_list) == 1 + assert not any(self.test_file_1.endswith(row[1]) for row in files_list) + + def test_add_deleting_file(self): + self.manager.init_tables() + + self.manager.add_files([self.test_file_1, self.test_file_2]) + self.manager.delete_files([hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest()]) + files_list = self.manager.list_files(details=True) + assert len(files_list) == 2 + files_list = self.manager.list_files(details=True, status=DocListManager.Status.deleting) + assert len(files_list) == 1 + documents = self.manager.add_files([self.test_file_1]) + assert documents == [] + + def test_update_file_message(self): + self.manager.init_tables() + + self.manager.add_files([self.test_file_1]) + file_id = hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest() + self.manager.update_file_message(file_id, meta="New metadata", status="processed") + + conn = sqlite3.connect(self.manager._db_path) + cursor = conn.execute("SELECT meta, status FROM documents WHERE doc_id = ?", (file_id,)) + row = cursor.fetchone() + conn.close() + + assert row[0] == "New metadata" + assert row[1] == "processed" + + def test_get_and_update_file_status(self): + self.manager.init_tables() + + file_id = hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest() + status = self.manager.get_file_status(file_id) + assert status[0] == DocListManager.Status.success + + self.manager.add_files([self.test_file_1], status=DocListManager.Status.waiting) + status = self.manager.get_file_status(file_id) + assert status[0] == DocListManager.Status.success + + self.manager.update_file_status([file_id], DocListManager.Status.waiting) + status = self.manager.get_file_status(file_id) + assert status[0] == DocListManager.Status.waiting + + def test_add_files_to_kb_group(self): + self.manager.init_tables() + files_list = self.manager.list_kb_group_files("group1", details=True) + assert len(files_list) == 0 + + self.manager.add_files([self.test_file_1, self.test_file_2]) + files_list = self.manager.list_kb_group_files("group1", details=True) + assert len(files_list) == 0 + + self.manager.add_files_to_kb_group(get_fid([self.test_file_1, self.test_file_2]), group="group1") + files_list = self.manager.list_kb_group_files("group1", details=True) + assert len(files_list) == 2 + + def test_delete_files_from_kb_group(self): + self.manager.init_tables() + + self.manager.add_files([self.test_file_1, self.test_file_2]) + self.manager.add_files_to_kb_group(get_fid([self.test_file_1, self.test_file_2]), group="group1") + + self.manager.delete_files_from_kb_group([hashlib.sha256(f'{self.test_file_1}'.encode()).hexdigest()], "group1") + files_list = self.manager.list_kb_group_files("group1", details=True) + # delete will literally erase the record + assert len(files_list) == 1 + + +@pytest.fixture(scope="class", autouse=True) +def setup_tmpdir_class(request, tmpdir_factory): + request.cls.tmpdir = tmpdir_factory.mktemp("class_tmpdir") + + +@pytest.mark.usefixtures("setup_tmpdir_class") +class TestDocListServer(object): + + @classmethod + def setup_class(cls): + cls.test_dir = test_dir = 
cls.tmpdir.mkdir("test_server") + + test_file_1, test_file_2 = test_dir.join("test1.txt"), test_dir.join("test2.txt") + test_file_1.write("This is a test file 1.") + test_file_2.write("This is a test file 2.") + cls.test_file_1, cls.test_file_2 = str(test_file_1), str(test_file_2) + + cls.manager = DocListManager(str(test_dir), "TestManager", False) + cls.manager.init_tables() + cls.manager.add_kb_group('group1') + cls.manager.add_kb_group('extra_group') + cls.server = lazyllm.ServerModule(DocManager(cls.manager)) + cls.server.start() + cls._test_inited = True + + test_file_extra = test_dir.join("test_extra.txt") + test_file_extra.write("This is a test file extra.") + cls.test_file_extra = str(test_file_extra) + cls.manager.add_files([cls.test_file_1, cls.test_file_2], status=DocListManager.Status.success) + time.sleep(15) + + def get_url(self, url, **kw): + url = (self.server._url.rsplit("/", 1)[0] + '/' + url).rstrip('/') + if kw: url += ('?' + '&'.join([f'{k}={v}' for k, v in kw.items()])) + return url + + def teardown_class(cls): + cls.server.stop() + shutil.rmtree(str(cls.test_dir)) + cls.manager.release() + + @pytest.mark.order(0) + def test_redirect_to_docs(self): + assert requests.get(self.get_url('')).status_code == 200 + assert requests.get(self.get_url('docs')).status_code == 200 + + @pytest.mark.order(1) + def test_list_kb_groups(self): + response = requests.get(self.get_url('list_kb_groups')) + assert response.status_code == 200 + assert response.json().get('data') == [DocListManager.DEFAULT_GROUP_NAME, 'group1', 'extra_group'] + + @pytest.mark.order(2) + def test_list_files(self): + response = requests.get(self.get_url('list_files')) + assert len(response.json().get('data')) == 2 + response = requests.get(self.get_url('list_files', limit=1)) + assert len(response.json().get('data')) == 1 + response = requests.get(self.get_url('list_files_in_group', group_name=DocListManager.DEFAULT_GROUP_NAME)) + assert len(response.json().get('data')) == 2 + response = requests.get(self.get_url('list_files_in_group', group_name='group1')) + assert len(response.json().get('data')) == 0 + + @pytest.mark.order(3) + def test_upload_files_and_upload_files_to_kb(self): + files = [('files', ('test1.txt', io.BytesIO(b"file1 content"), 'text/plain')), + ('files', ('test2.txt', io.BytesIO(b"file2 content"), 'text/plain'))] + + data = dict(override='true', metadatas=json.dumps([{"key": "value"}, {"key": "value2"}]), user_path='path') + response = requests.post(self.get_url('upload_files', **data), files=files) + assert response.status_code == 200 and response.json().get('code') == 200, response.json() + assert len(response.json().get('data')[0]) == 2 + + response = requests.get(self.get_url('list_files', details=False)) + ids = response.json().get('data') + assert response.status_code == 200 and len(ids) == 4 + + # add_files_to_group + files = [('files', ('test3.txt', io.BytesIO(b"file3 content"), 'text/plain'))] + data = dict(override='false', metadatas=json.dumps([{"key": "value"}]), group_name='group1') + response = requests.post(self.get_url('add_files_to_group', **data), files=files) + assert response.status_code == 200 + + response = requests.get(self.get_url('list_files', details=True)) + assert response.status_code == 200 and len(response.json().get('data')) == 5 + response = requests.get(self.get_url('list_files_in_group', group_name='group1')) + assert response.status_code == 200 and len(response.json().get('data')) == 1 + + @pytest.mark.order(4) + def 
test_add_files_to_group_and_delete_files_from_group(self): + response = requests.get(self.get_url('list_files', details=False)) + ids = response.json().get('data') + assert response.status_code == 200 and len(ids) == 5 + requests.post(self.get_url('add_files_to_group_by_id'), json=dict(file_ids=ids[:2], group_name='group1')) + response = requests.get(self.get_url('list_files_in_group', group_name='group1')) + assert response.status_code == 200 and len(response.json().get('data')) == 3 + + requests.post(self.get_url('delete_files_from_group'), json=dict(file_ids=ids[:1], group_name='group1')) + response = requests.get(self.get_url('list_files_in_group', group_name='group1')) + assert response.status_code == 200 and len(response.json().get('data')) == 3 + response = requests.get(self.get_url('list_files_in_group', group_name='group1', alive=True)) + assert response.status_code == 200 and len(response.json().get('data')) == 2 + + @pytest.mark.order(5) + def test_delete_files(self): + response = requests.get(self.get_url('list_files', details=False)) + ids = response.json().get('data') + assert response.status_code == 200 and len(ids) == 5 + + response = requests.post(self.get_url('delete_files'), json=dict(file_ids=ids[-1:])) + lazyllm.LOG.warning(response.json()) + assert response.status_code == 200 and response.json().get('code') == 200 + + response = requests.get(self.get_url('list_files')) + assert response.status_code == 200 and len(response.json().get('data')) == 5 + response = requests.get(self.get_url('list_files', alive=True)) + assert response.status_code == 200 and len(response.json().get('data')) == 4 + + response = requests.get(self.get_url('list_files_in_group', group_name='group1')) + assert response.status_code == 200 and len(response.json().get('data')) == 3 + response = requests.get(self.get_url('list_files_in_group', group_name='group1', alive=True)) + assert response.status_code == 200 and len(response.json().get('data')) == 1 + + @pytest.mark.order(6) + def test_add_files(self): + json_data = { + 'files': [self.test_file_extra, "fake path"], + 'group_name': "extra_group", + 'metadatas': json.dumps([{"key": "value"}, {"key": "value"}]) + } + response = requests.post(self.get_url('add_files'), json=json_data) + assert response.status_code == 200 + assert len(response.json().get('data')) == 2 and response.json().get('data')[1] is None diff --git a/tests/basic_tests/RAG/test_doc_processor.py b/tests/basic_tests/RAG/test_doc_processor.py index 30d0a91ae..eefcada2b 100644 --- a/tests/basic_tests/RAG/test_doc_processor.py +++ b/tests/basic_tests/RAG/test_doc_processor.py @@ -6,7 +6,8 @@ import pytest from lazyllm.tools.rag.parsing_service import DocumentProcessor -from lazyllm.tools.rag.parsing_service.base import TaskStatus +from lazyllm.tools.rag.parsing_service.base import TaskStatus, TaskType +from lazyllm.tools.rag.parsing_service.worker import DocumentProcessorWorker from lazyllm import Document, Retriever STATIC_STATUS = [TaskStatus.SUCCESS.value, TaskStatus.FAILED.value, TaskStatus.CANCELED.value] @@ -132,6 +133,49 @@ def test_upload_doc(self): except requests.exceptions.RequestException as e: self.fail(f'Request failed: {e}') + +class _FakeFinishedTaskQueue: + def __init__(self): + self.items = [] + + def enqueue(self, **kwargs): + self.items.append(kwargs) + + +def test_worker_callback_filters_allow_all_by_default(): + impl = DocumentProcessorWorker._Impl() + queue = _FakeFinishedTaskQueue() + impl._lazy_init = lambda: None + impl._finished_task_queue = queue + + 
impl._enqueue_finished_task( + task_id='task-1', + task_type=TaskType.DOC_ADD.value, + task_status=TaskStatus.WORKING, + callback_url='http://callback.test', + ) + + assert [item['task_id'] for item in queue.items] == ['task-1'] + + +def test_worker_callback_filters_skip_non_matching_records(): + impl = DocumentProcessorWorker._Impl( + callback_task_statuses=[TaskStatus.SUCCESS.value], + callback_task_types=[TaskType.DOC_ADD.value], + ) + queue = _FakeFinishedTaskQueue() + impl._lazy_init = lambda: None + impl._finished_task_queue = queue + + impl._enqueue_finished_task(task_id='task-delete', task_type=TaskType.DOC_DELETE.value, + task_status=TaskStatus.SUCCESS) + impl._enqueue_finished_task(task_id='task-working', task_type=TaskType.DOC_ADD.value, + task_status=TaskStatus.WORKING) + impl._enqueue_finished_task(task_id='task-add-success', task_type=TaskType.DOC_ADD.value, + task_status=TaskStatus.SUCCESS) + + assert [item['task_id'] for item in queue.items] == ['task-add-success'] + @pytest.mark.order(2) def test_reparse(self): with open(self._file_path, 'w') as f: diff --git a/tests/basic_tests/RAG/test_doc_service_doc_manager.py b/tests/basic_tests/RAG/test_doc_service_doc_manager.py index 7e92df718..9ace9c6be 100644 --- a/tests/basic_tests/RAG/test_doc_service_doc_manager.py +++ b/tests/basic_tests/RAG/test_doc_service_doc_manager.py @@ -20,7 +20,8 @@ TransferRequest, UploadRequest, ) -from lazyllm.tools.rag.doc_service.doc_manager import DocManager, _ParserClient +from lazyllm.tools.rag.doc_service.doc_manager import DocManager +from lazyllm.tools.rag.doc_service.parser_client import ParserClient from lazyllm.tools.rag.parsing_service.base import TaskType from lazyllm.tools.rag.utils import BaseResponse @@ -38,7 +39,12 @@ def __init__(self): 'port': None, 'db_name': os.path.join(self.tmp_dir, 'doc_service_local.db'), } - self.manager = DocManager(db_config=self.db_config, parser_url='http://parser.test') + original_health = ParserClient.health + ParserClient.health = lambda self: BaseResponse(code=200, msg='success', data={'ok': True}) + try: + self.manager = DocManager(db_config=self.db_config, parser_url='http://parser.test') + finally: + ParserClient.health = original_health self.pending_task_status = {} self.cancel_calls = [] self.delete_calls = [] @@ -331,6 +337,35 @@ def test_manager_transfer_uses_target_doc_id_for_target_records(manager_harness) assert manager_harness.manager._has_kb_document('kb_transfer_source', 'source-doc') is True +def test_manager_transfer_same_kb_copy_reuses_source_path(manager_harness): + manager_harness.manager.create_kb('kb_same_transfer', algo_id='__default__') + file_path = manager_harness.make_file('same-transfer.txt', 'same transfer content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_same_transfer', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='same-source-doc')], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='same-source-doc', + target_doc_id='same-target-doc', + source_kb_id='kb_same_transfer', + source_algo_id='__default__', + target_kb_id='kb_same_transfer', + target_algo_id='__default__', + mode='copy', + )])) + + assert items[0]['accepted'] is True + assert items[0]['status'] == DocStatus.WAITING.value + assert items[0]['target_file_path'] == file_path + assert manager_harness.add_doc_calls[-1]['doc_id'] == 'same-source-doc' + assert 
manager_harness.add_doc_calls[-1]['transfer_params']['target_doc_id'] == 'same-target-doc' + assert manager_harness.manager._has_kb_document('kb_same_transfer', 'same-source-doc') is True + assert manager_harness.manager._has_kb_document('kb_same_transfer', 'same-target-doc') is True + + def test_manager_transfer_move_cleans_source_doc_with_target_doc_id(manager_harness): manager_harness.manager.create_kb('kb_move_source', algo_id='__default__') manager_harness.manager.create_kb('kb_move_target', algo_id='__default__') @@ -520,10 +555,12 @@ def test_manager_callback_payload_fallback_and_delete_transition(manager_harness def test_parser_client_algo_endpoint_fallback(): - client = _ParserClient(parser_url='http://parser.test') + client = ParserClient(parser_url='http://parser.test') calls = [] - def fake_get(path, params=None): + def fake_request(method, path, params=None, **kwargs): + assert method == 'GET' + assert kwargs == {} del params calls.append(path) if path == '/v1/algo/list': @@ -544,7 +581,7 @@ def fake_get(path, params=None): } raise AssertionError(path) - client._get = fake_get + client._request = fake_request algo_resp = client.list_algorithms() group_resp = client.get_algorithm_groups('__default__') diff --git a/tests/basic_tests/RAG/test_doc_service_mock.py b/tests/basic_tests/RAG/test_doc_service_mock.py deleted file mode 100644 index 7a3e67b34..000000000 --- a/tests/basic_tests/RAG/test_doc_service_mock.py +++ /dev/null @@ -1,994 +0,0 @@ -import io -import os -import socket -import shutil -import tempfile -import time -from concurrent.futures import ThreadPoolExecutor -from datetime import datetime -from uuid import uuid4 - -import pytest -import requests - -from lazyllm.tools.rag.doc_service import DocServer -from lazyllm.tools.rag.doc_service.base import ( - AddFileItem, CallbackEventType, DeleteRequest, DocServiceError, DocStatus, KbUpdateRequest, ReparseRequest, - SourceType, TaskCallbackRequest, UploadRequest, -) -from lazyllm.tools.rag.doc_service.doc_manager import DocManager, _ParserClient -from lazyllm.tools.rag.parsing_service.base import TaskType -from lazyllm.tools.rag.utils import BaseResponse - - -@pytest.mark.skip_on_win -class TestDocServiceMock: - @staticmethod - def _ensure_bindable(): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - sock.bind(('127.0.0.1', 0)) - except OSError as exc: - if 'operation not permitted' in str(exc).lower(): - pytest.skip('Socket bind is not permitted in current environment') - raise - finally: - sock.close() - - @classmethod - def setup_class(cls): - cls._ensure_bindable() - cls._parser_url = os.getenv('LAZYLLM_DOC_SERVICE_TEST_PARSER_URL') - if not cls._parser_url: - pytest.skip('LAZYLLM_DOC_SERVICE_TEST_PARSER_URL is required for real parser integration tests') - cls._tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_') - cls._storage_dir = os.path.join(cls._tmp_dir, 'uploads') - os.makedirs(cls._storage_dir, exist_ok=True) - - cls._seed_path = os.path.join(cls._tmp_dir, 'seed.txt') - with open(cls._seed_path, 'w', encoding='utf-8') as f: - f.write('seed content') - - cls._db_config = { - 'db_type': 'sqlite', - 'user': None, - 'password': None, - 'host': None, - 'port': None, - 'db_name': os.path.join(cls._tmp_dir, 'doc_service.db'), - } - cls.server = DocServer( - db_config=cls._db_config, - parser_url=cls._parser_url, - storage_dir=cls._storage_dir, - ) - cls.server.start() - cls.base_url = cls.server._impl._url.rsplit('/', 1)[0] - deadline = time.time() + 10 - while time.time() < deadline: - try: 
- resp = requests.get(f'{cls.base_url}/v1/health', timeout=3) - if resp.status_code == 200: - break - except Exception: - pass - time.sleep(0.2) - - @classmethod - def teardown_class(cls): - cls.server.stop() - shutil.rmtree(cls._tmp_dir, ignore_errors=True) - - def _wait_task(self, task_id, target_statuses, timeout=8): - deadline = time.time() + timeout - last = None - while time.time() < deadline: - resp = requests.get(f'{self.base_url}/v1/tasks/{task_id}', timeout=5) - assert resp.status_code == 200 - last = resp.json()['data'] - if last['status'] in target_statuses: - return last - time.sleep(0.1) - raise AssertionError(f'task {task_id} not finished in time, last={last}') - - def test_p0_endpoints_and_core_flows(self): - health = requests.get(f'{self.base_url}/v1/health', timeout=5) - assert health.status_code == 200 - - kb_create = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_a'}, timeout=5) - assert kb_create.status_code == 200 - assert kb_create.json()['data']['kb_id'] == 'kb_a' - - kb_list = requests.get(f'{self.base_url}/v1/kbs', timeout=5) - assert kb_list.status_code == 200 - assert any(item['kb_id'] == 'kb_a' for item in kb_list.json()['data']['items']) - - algo_list = requests.get(f'{self.base_url}/v1/algo/list', timeout=5) - assert algo_list.status_code == 200 - assert any(item['algo_id'] == '__default__' for item in algo_list.json()['data']) - - algo_groups = requests.get(f'{self.base_url}/v1/algo/__default__/groups', timeout=5) - assert algo_groups.status_code == 200 - assert len(algo_groups.json()['data']) > 0 - - upload = requests.post( - f'{self.base_url}/v1/docs/upload', - params={'kb_id': 'kb_a', 'algo_id': '__default__'}, - files=[('files', ('upload.txt', io.BytesIO(b'upload content'), 'text/plain'))], - timeout=8, - ) - assert upload.status_code == 200 - upload_item = upload.json()['data']['items'][0] - doc_upload = upload_item['doc_id'] - upload_task = upload_item['task_id'] - self._wait_task(upload_task, {'SUCCESS'}) - - add = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_a', - 'algo_id': '__default__', - 'items': [{'file_path': self._seed_path, 'doc_id': 'seed-doc-1', 'metadata': {'owner': 'u1'}}], - }, - timeout=8, - ) - assert add.status_code == 200 - add_item = add.json()['data']['items'][0] - doc_add = add_item['doc_id'] - add_task = add_item['task_id'] - self._wait_task(add_task, {'SUCCESS'}) - - meta_patch = requests.post( - f'{self.base_url}/v1/docs/metadata/patch', - json={ - 'kb_id': 'kb_a', - 'algo_id': '__default__', - 'items': [{'doc_id': doc_add, 'patch': {'tag': 'patched'}}], - }, - timeout=8, - ) - assert meta_patch.status_code == 200 - meta_task = meta_patch.json()['data']['items'][0]['task_id'] - self._wait_task(meta_task, {'SUCCESS'}) - - reparse = requests.post( - f'{self.base_url}/v1/docs/reparse', - json={'kb_id': 'kb_a', 'algo_id': '__default__', 'doc_ids': [doc_add]}, - timeout=8, - ) - assert reparse.status_code == 200 - reparse_task = reparse.json()['data']['task_ids'][0] - self._wait_task(reparse_task, {'SUCCESS'}) - - kb_b = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_b'}, timeout=5) - assert kb_b.status_code == 200 - - transfer = requests.post( - f'{self.base_url}/v1/docs/transfer', - json={ - 'items': [ - { - 'doc_id': doc_add, - 'source_kb_id': 'kb_a', - 'source_algo_id': '__default__', - 'target_kb_id': 'kb_b', - 'target_algo_id': '__default__', - 'mode': 'copy', - } - ] - }, - timeout=8, - ) - assert transfer.status_code == 200 - transfer_task = 
transfer.json()['data']['items'][0]['task_id'] - self._wait_task(transfer_task, {'SUCCESS'}) - - docs = requests.get( - f'{self.base_url}/v1/docs', - params={'kb_id': 'kb_a', 'include_deleted_or_canceled': True, 'keyword': 'seed'}, - timeout=8, - ) - assert docs.status_code == 200 - assert docs.json()['data']['total'] >= 1 - - doc_detail = requests.get(f'{self.base_url}/v1/docs/{doc_add}', timeout=8) - assert doc_detail.status_code == 200 - assert doc_detail.json()['data']['doc']['metadata'].get('tag') == 'patched' - - tasks = requests.get(f'{self.base_url}/v1/tasks', params={'status': ['SUCCESS', 'WAITING']}, timeout=8) - assert tasks.status_code == 200 - assert tasks.json()['data']['total'] >= 1 - - task_detail = requests.get(f'{self.base_url}/v1/tasks/{reparse_task}', timeout=8) - assert task_detail.status_code == 200 - - cancel = requests.post(f'{self.base_url}/v1/tasks/cancel', json={'task_id': reparse_task}, timeout=8) - assert cancel.status_code == 200 - assert cancel.json()['data']['task_id'] == reparse_task - - delete = requests.post( - f'{self.base_url}/v1/docs/delete', - json={'kb_id': 'kb_a', 'algo_id': '__default__', 'doc_ids': [doc_upload]}, - timeout=8, - ) - assert delete.status_code == 200 - delete_task = delete.json()['data']['items'][0]['task_id'] - self._wait_task(delete_task, {'DELETED'}) - - docs_filtered = requests.get( - f'{self.base_url}/v1/docs', - params={'kb_id': 'kb_a', 'include_deleted_or_canceled': False}, - timeout=8, - ) - assert docs_filtered.status_code == 200 - - cb = requests.post( - f'{self.base_url}/v1/internal/callbacks/tasks', - json={ - 'task_id': 'non-exist-task', - 'event_type': 'FINISH', - 'status': 'SUCCESS', - 'payload': {'task_type': 'DOC_ADD', 'doc_id': 'nope', 'kb_id': 'kb_a', 'algo_id': '__default__'}, - }, - timeout=8, - ) - assert cb.status_code == 200 - assert cb.json()['data']['ack'] is True - - kb_delete = requests.delete(f'{self.base_url}/v1/kbs/kb_a', timeout=8) - assert kb_delete.status_code == 200 - - def test_missing_p0_endpoints_exist(self): - kb_create = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_endpoints'}, timeout=5) - assert kb_create.status_code == 200 - - chunks = requests.get(f'{self.base_url}/v1/chunks', timeout=5) - assert chunks.status_code == 200 - assert chunks.json()['data']['items'] == [] - - algorithms = requests.get(f'{self.base_url}/v1/algorithms', timeout=5) - assert algorithms.status_code == 200 - assert len(algorithms.json()['data']['items']) >= 1 - - algo_info = requests.post( - f'{self.base_url}/v1/algorithms/info', json={'algo_id': '__default__'}, timeout=5, - ) - assert algo_info.status_code == 200 - assert algo_info.json()['data']['algo_id'] == '__default__' - - add = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_endpoints', - 'algo_id': '__default__', - 'items': [{'file_path': self._seed_path, 'doc_id': 'seed-doc-endpoints'}], - }, - timeout=8, - ) - assert add.status_code == 200 - task_id = add.json()['data']['items'][0]['task_id'] - - task_info = requests.post(f'{self.base_url}/v1/tasks/info', json={'task_id': task_id}, timeout=5) - assert task_info.status_code == 200 - assert task_info.json()['data']['task_id'] == task_id - - task_batch = requests.post(f'{self.base_url}/v1/tasks/batch', json={'task_ids': [task_id]}, timeout=5) - assert task_batch.status_code == 200 - assert len(task_batch.json()['data']['items']) == 1 - - kb_delete = requests.delete(f'{self.base_url}/v1/kbs', json={'kb_ids': ['kb_endpoints']}, timeout=8) - assert kb_delete.status_code == 
200 - assert len(kb_delete.json()['data']['items']) == 1 - - def test_idempotency_replay_and_conflict(self): - file_path = os.path.join(self._tmp_dir, 'idem.txt') - with open(file_path, 'w', encoding='utf-8') as f: - f.write('idempotent content') - create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_idem'}, timeout=5) - assert create_kb.status_code == 200 - - payload = { - 'kb_id': 'kb_idem', - 'algo_id': '__default__', - 'idempotency_key': 'idem-add-key', - 'items': [{'file_path': file_path, 'doc_id': 'idem-doc-1'}], - } - first = requests.post(f'{self.base_url}/v1/docs/add', json=payload, timeout=8) - second = requests.post(f'{self.base_url}/v1/docs/add', json=payload, timeout=8) - assert first.status_code == 200 - assert second.status_code == 200 - assert first.json()['data']['items'][0]['task_id'] == second.json()['data']['items'][0]['task_id'] - - conflict_payload = dict(payload) - conflict_payload['items'] = [{'file_path': file_path, 'doc_id': 'idem-doc-2'}] - conflict = requests.post(f'{self.base_url}/v1/docs/add', json=conflict_payload, timeout=8) - assert conflict.status_code == 409 - assert conflict.json()['data']['biz_code'] == 'E_IDEMPOTENCY_CONFLICT' - - def test_upload_idempotency_replay_and_conflict(self): - create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_upload_idem'}, timeout=5) - assert create_kb.status_code == 200 - - params = { - 'kb_id': 'kb_upload_idem', - 'algo_id': '__default__', - 'idempotency_key': 'idem-upload-key', - } - first = requests.post( - f'{self.base_url}/v1/docs/upload', - params=params, - files=[('files', ('idem-upload.txt', io.BytesIO(b'upload idem content'), 'text/plain'))], - timeout=8, - ) - second = requests.post( - f'{self.base_url}/v1/docs/upload', - params=params, - files=[('files', ('idem-upload.txt', io.BytesIO(b'upload idem content'), 'text/plain'))], - timeout=8, - ) - assert first.status_code == 200 - assert second.status_code == 200 - assert first.json()['data']['items'][0]['task_id'] == second.json()['data']['items'][0]['task_id'] - - conflict = requests.post( - f'{self.base_url}/v1/docs/upload', - params=params, - files=[('files', ('idem-upload.txt', io.BytesIO(b'upload idem changed'), 'text/plain'))], - timeout=8, - ) - assert conflict.status_code == 409 - assert conflict.json()['data']['biz_code'] == 'E_IDEMPOTENCY_CONFLICT' - - def test_add_same_path_with_different_doc_id_returns_conflict(self): - create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_path_conflict'}, timeout=5) - assert create_kb.status_code == 200 - - first = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_path_conflict', - 'algo_id': '__default__', - 'items': [{'file_path': self._seed_path, 'doc_id': 'path-doc-1'}], - }, - timeout=8, - ) - assert first.status_code == 200 - - conflict = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_path_conflict', - 'algo_id': '__default__', - 'items': [{'file_path': self._seed_path, 'doc_id': 'path-doc-2'}], - }, - timeout=8, - ) - assert conflict.status_code == 409 - body = conflict.json() - assert body['data']['biz_code'] == 'E_STATE_CONFLICT' - assert body['data']['path'] == self._seed_path - assert body['data']['doc_id'] == 'path-doc-1' - - def test_upload_same_filename_does_not_override_existing_file(self): - create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_same_name'}, timeout=5) - assert create_kb.status_code == 200 - - first = requests.post( - f'{self.base_url}/v1/docs/upload', - 
params={'kb_id': 'kb_same_name', 'algo_id': '__default__'}, - files=[('files', ('same-name.txt', io.BytesIO(b'first content'), 'text/plain'))], - timeout=8, - ) - second = requests.post( - f'{self.base_url}/v1/docs/upload', - params={'kb_id': 'kb_same_name', 'algo_id': '__default__'}, - files=[('files', ('same-name.txt', io.BytesIO(b'second content'), 'text/plain'))], - timeout=8, - ) - assert first.status_code == 200 - assert second.status_code == 200 - - first_item = first.json()['data']['items'][0] - second_item = second.json()['data']['items'][0] - assert first_item['doc_id'] != second_item['doc_id'] - self._wait_task(first_item['task_id'], {'SUCCESS'}) - self._wait_task(second_item['task_id'], {'SUCCESS'}) - - first_detail = requests.get(f'{self.base_url}/v1/docs/{first_item["doc_id"]}', timeout=8) - second_detail = requests.get(f'{self.base_url}/v1/docs/{second_item["doc_id"]}', timeout=8) - assert first_detail.status_code == 200 - assert second_detail.status_code == 200 - assert first_detail.json()['data']['doc']['path'] != second_detail.json()['data']['doc']['path'] - - def test_idempotency_atomic_claim(self): - file_path = os.path.join(self._tmp_dir, 'idem_atomic.txt') - with open(file_path, 'w', encoding='utf-8') as f: - f.write('idempotent atomic content') - create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_idem_atomic'}, timeout=5) - assert create_kb.status_code == 200 - - payload = { - 'kb_id': 'kb_idem_atomic', - 'algo_id': '__default__', - 'idempotency_key': 'idem-atomic-key', - 'items': [{'file_path': file_path, 'doc_id': 'idem-atomic-doc'}], - } - - def _send(): - return requests.post(f'{self.base_url}/v1/docs/add', json=payload, timeout=8) - - with ThreadPoolExecutor(max_workers=6) as pool: - responses = list(pool.map(lambda _: _send(), range(6))) - - statuses = [resp.status_code for resp in responses] - assert all(status in (200, 409) for status in statuses) - success_payloads = [resp.json()['data'] for resp in responses if resp.status_code == 200] - unique_task_ids = {item['items'][0]['task_id'] for item in success_payloads} - assert len(unique_task_ids) == 1 - for resp in responses: - if resp.status_code == 409: - assert resp.json()['data']['biz_code'] in {'E_IDEMPOTENCY_IN_PROGRESS', 'E_IDEMPOTENCY_CONFLICT'} - - replay = requests.post(f'{self.base_url}/v1/docs/add', json=payload, timeout=8) - assert replay.status_code == 200 - assert replay.json()['data']['items'][0]['task_id'] in unique_task_ids - - def test_illegal_state_transition(self): - file_path = os.path.join(self._tmp_dir, 'illegal.txt') - with open(file_path, 'w', encoding='utf-8') as f: - f.write('illegal transition content') - create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_illegal'}, timeout=5) - assert create_kb.status_code == 200 - - add = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_illegal', - 'algo_id': '__default__', - 'items': [{'file_path': file_path, 'doc_id': 'illegal-doc-1'}], - }, - timeout=8, - ) - assert add.status_code == 200 - doc_id = add.json()['data']['items'][0]['doc_id'] - task_id = add.json()['data']['items'][0]['task_id'] - self._wait_task(task_id, {'SUCCESS'}) - - delete = requests.post( - f'{self.base_url}/v1/docs/delete', - json={'kb_id': 'kb_illegal', 'algo_id': '__default__', 'doc_ids': [doc_id]}, - timeout=8, - ) - assert delete.status_code == 200 - - reparse_while_deleting = requests.post( - f'{self.base_url}/v1/docs/reparse', - json={'kb_id': 'kb_illegal', 'algo_id': '__default__', 'doc_ids': 
[doc_id]}, - timeout=8, - ) - assert reparse_while_deleting.status_code == 409 - assert reparse_while_deleting.json()['data']['biz_code'] == 'E_STATE_CONFLICT' - - delete_again = requests.post( - f'{self.base_url}/v1/docs/delete', - json={'kb_id': 'kb_illegal', 'algo_id': '__default__', 'doc_ids': [doc_id]}, - timeout=8, - ) - assert delete_again.status_code == 409 - - add_again = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_illegal', - 'algo_id': '__default__', - 'items': [{'file_path': file_path, 'doc_id': doc_id}], - }, - timeout=8, - ) - assert add_again.status_code == 409 - - def test_kb_algo_binding_and_transfer_validation(self): - create_kb = requests.post( - f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_bind', 'algo_id': '__default__'}, timeout=5, - ) - assert create_kb.status_code == 200 - - rebind = requests.post( - f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_bind', 'algo_id': 'another_algo'}, timeout=5, - ) - assert rebind.status_code == 409 - assert rebind.json()['data']['biz_code'] == 'E_STATE_CONFLICT' - - file_path = os.path.join(self._tmp_dir, 'binding.txt') - with open(file_path, 'w', encoding='utf-8') as f: - f.write('binding content') - - mismatch = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_bind', - 'algo_id': 'another_algo', - 'items': [{'file_path': file_path, 'doc_id': 'bind-doc'}], - }, - timeout=8, - ) - assert mismatch.status_code == 400 - assert mismatch.json()['data']['biz_code'] == 'E_INVALID_PARAM' - - add = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_bind', - 'algo_id': '__default__', - 'items': [{'file_path': file_path, 'doc_id': 'bind-doc'}], - }, - timeout=8, - ) - assert add.status_code == 200 - doc_id = add.json()['data']['items'][0]['doc_id'] - self._wait_task(add.json()['data']['items'][0]['task_id'], {'SUCCESS'}) - - invalid_transfer = requests.post( - f'{self.base_url}/v1/docs/transfer', - json={ - 'items': [{ - 'doc_id': doc_id, - 'source_kb_id': 'kb_bind', - 'source_algo_id': '__default__', - 'target_kb_id': 'kb_bind', - 'target_algo_id': '__default__', - 'mode': 'invalid', - }] - }, - timeout=8, - ) - assert invalid_transfer.status_code == 400 - assert invalid_transfer.json()['data']['biz_code'] == 'E_INVALID_PARAM' - - def test_stale_callback_ignored(self): - file_path = os.path.join(self._tmp_dir, 'stale.txt') - with open(file_path, 'w', encoding='utf-8') as f: - f.write('stale callback content') - create_kb = requests.post(f'{self.base_url}/v1/kbs', json={'kb_id': 'kb_stale'}, timeout=5) - assert create_kb.status_code == 200 - - add = requests.post( - f'{self.base_url}/v1/docs/add', - json={ - 'kb_id': 'kb_stale', - 'algo_id': '__default__', - 'items': [{'file_path': file_path, 'doc_id': 'stale-doc-1'}], - }, - timeout=8, - ) - assert add.status_code == 200 - doc_id = add.json()['data']['items'][0]['doc_id'] - self._wait_task(add.json()['data']['items'][0]['task_id'], {'SUCCESS'}) - - first = requests.post( - f'{self.base_url}/v1/docs/reparse', - json={'kb_id': 'kb_stale', 'algo_id': '__default__', 'doc_ids': [doc_id]}, - timeout=8, - ) - assert first.status_code == 200 - first_task_id = first.json()['data']['task_ids'][0] - - second = requests.post( - f'{self.base_url}/v1/docs/reparse', - json={'kb_id': 'kb_stale', 'algo_id': '__default__', 'doc_ids': [doc_id]}, - timeout=8, - ) - assert second.status_code == 200 - second_task_id = second.json()['data']['task_ids'][0] - assert first_task_id != second_task_id - - stale = requests.post( - 
f'{self.base_url}/v1/internal/callbacks/tasks', - json={ - 'callback_id': 'stale-callback-1', - 'task_id': first_task_id, - 'event_type': 'FINISH', - 'status': 'SUCCESS', - }, - timeout=8, - ) - assert stale.status_code == 200 - assert stale.json()['data']['ignored_reason'] == 'stale_task_callback' - - duplicate = requests.post( - f'{self.base_url}/v1/internal/callbacks/tasks', - json={ - 'callback_id': 'stale-callback-1', - 'task_id': first_task_id, - 'event_type': 'FINISH', - 'status': 'SUCCESS', - }, - timeout=8, - ) - assert duplicate.status_code == 200 - assert duplicate.json()['data']['deduped'] is True - - def test_get_doc_404_is_wrapped(self): - resp = requests.get(f'{self.base_url}/v1/docs/not-exists-doc', timeout=5) - assert resp.status_code == 404 - body = resp.json() - assert body['code'] == 404 - assert body['data']['biz_code'] == 'E_NOT_FOUND' - - def test_delete_kbs_empty_payload_returns_400(self): - resp = requests.delete(f'{self.base_url}/v1/kbs', json={'kb_ids': []}, timeout=5) - assert resp.status_code == 400 - assert resp.json()['data']['biz_code'] == 'E_INVALID_PARAM' - - def test_kb_update_pagination_and_batch_query(self): - first = requests.post( - f'{self.base_url}/v1/kbs', - json={'kb_id': 'kb_page_1', 'display_name': 'Page 1', 'algo_id': '__default__'}, - timeout=5, - ) - second = requests.post( - f'{self.base_url}/v1/kbs', - json={'kb_id': 'kb_page_2', 'display_name': 'Page 2', 'algo_id': '__default__'}, - timeout=5, - ) - assert first.status_code == 200 - assert second.status_code == 200 - - paged = requests.get(f'{self.base_url}/v1/kbs', params={'page': 1, 'page_size': 1}, timeout=5) - assert paged.status_code == 200 - paged_data = paged.json()['data'] - assert paged_data['page'] == 1 - assert paged_data['page_size'] == 1 - assert paged_data['total'] >= 2 - assert len(paged_data['items']) == 1 - - detail = requests.get(f'{self.base_url}/v1/kbs/kb_page_1', timeout=5) - assert detail.status_code == 200 - assert detail.json()['data']['algo_id'] == '__default__' - - updated = requests.post( - f'{self.base_url}/v1/kbs/kb_page_1/update', - json={ - 'display_name': 'Page 1 Updated', - 'description': 'updated description', - 'owner_id': 'owner-a', - 'meta': {'scene': 'pagination-test'}, - }, - timeout=5, - ) - assert updated.status_code == 200 - updated_data = updated.json()['data'] - assert updated_data['display_name'] == 'Page 1 Updated' - assert updated_data['meta']['scene'] == 'pagination-test' - - batch = requests.post( - f'{self.base_url}/v1/kbs/batch', - json={'kb_ids': ['kb_page_1', 'kb_missing']}, - timeout=5, - ) - assert batch.status_code == 200 - batch_data = batch.json()['data'] - assert len(batch_data['items']) == 1 - assert batch_data['items'][0]['kb_id'] == 'kb_page_1' - assert batch_data['missing_kb_ids'] == ['kb_missing'] - - -class TestDocServiceMockLocal: - @classmethod - def setup_class(cls): - cls._tmp_dir = tempfile.mkdtemp(prefix='lazyllm_doc_service_local_') - cls._seed_path = os.path.join(cls._tmp_dir, 'seed.txt') - with open(cls._seed_path, 'w', encoding='utf-8') as f: - f.write('local seed content') - cls._db_config = { - 'db_type': 'sqlite', - 'user': None, - 'password': None, - 'host': None, - 'port': None, - 'db_name': os.path.join(cls._tmp_dir, 'doc_service_local.db'), - } - cls.manager = DocManager(db_config=cls._db_config, parser_url='http://parser.test') - cls._pending_task_status = {} - - def _queue_task(task_id: str, final_status: DocStatus): - cls._pending_task_status[task_id] = final_status - - def _add_doc(task_id, algo_id, 
kb_id, doc_id, file_path, metadata=None, reparse_group=None): - _queue_task(task_id, DocStatus.SUCCESS) - return BaseResponse( - code=200, - msg='success', - data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}, - ) - - cls.manager._parser_client.add_doc = _add_doc - - def _update_meta(task_id, algo_id, kb_id, doc_id, metadata=None, file_path=None): - _queue_task(task_id, DocStatus.SUCCESS) - return BaseResponse( - code=200, - msg='success', - data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}, - ) - - cls.manager._parser_client.update_meta = _update_meta - - def _delete_doc(task_id, algo_id, kb_id, doc_id): - _queue_task(task_id, DocStatus.SUCCESS) - return BaseResponse( - code=200, - msg='success', - data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}, - ) - - cls.manager._parser_client.delete_doc = _delete_doc - cls.manager._parser_client.cancel_task = lambda task_id: BaseResponse( - code=200, msg='success', data={'task_id': task_id, 'cancel_status': True} - ) - cls.manager._parser_client.list_algorithms = lambda: BaseResponse( - code=200, msg='success', data=[{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}] - ) - cls.manager._parser_client.get_algorithm_groups = lambda algo_id: BaseResponse( - code=200, - msg='success', - data=[{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}] if algo_id == '__default__' else None, - ) - - @classmethod - def teardown_class(cls): - shutil.rmtree(cls._tmp_dir, ignore_errors=True) - - def _wait_task(self, task_id, target_statuses, timeout=8): - deadline = time.time() + timeout - last = None - while time.time() < deadline: - resp = self.manager.get_task(task_id) - assert resp.code == 200 - last = resp.data - if last['status'] in target_statuses: - return last - pending_status = self._pending_task_status.pop(task_id, None) - if pending_status is not None: - self.manager.on_task_callback(TaskCallbackRequest( - task_id=task_id, - event_type=CallbackEventType.FINISH, - status=pending_status, - )) - time.sleep(0.05) - raise AssertionError(f'task {task_id} not finished in time, last={last}') - - def _make_file(self, name: str, content: str): - file_path = os.path.join(self._tmp_dir, name) - with open(file_path, 'w', encoding='utf-8') as f: - f.write(content) - return file_path - - def test_manager_atomic_idempotency(self): - started = [] - - def handler(): - started.append(time.time()) - time.sleep(0.2) - return {'task_id': str(uuid4())} - - with ThreadPoolExecutor(max_workers=2) as pool: - future = pool.submit(self.manager.run_idempotent, '/local/atomic', 'same-key', {'k': 1}, handler) - time.sleep(0.05) - with pytest.raises(DocServiceError) as exc: - self.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) - result = future.result(timeout=2) - - assert exc.value.biz_code == 'E_IDEMPOTENCY_IN_PROGRESS' - replay = self.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) - assert len(started) == 1 - assert replay == result - - def test_manager_kb_algo_binding(self): - self.manager.create_kb('kb_local_bind', algo_id='__default__') - file_path = self._make_file('local_bind.txt', 'local bind content') - with pytest.raises(DocServiceError) as exc: - self.manager.upload(UploadRequest( - kb_id='kb_local_bind', - algo_id='wrong_algo', - items=[AddFileItem(file_path=file_path, doc_id='local-bind-doc')], - )) - assert exc.value.biz_code == 'E_INVALID_PARAM' - - def test_manager_stale_callback_and_state_conflict(self): - self.manager.create_kb('kb_local_stale', 
algo_id='__default__') - file_path = self._make_file('local_stale.txt', 'local stale content') - uploaded = self.manager.upload(UploadRequest( - kb_id='kb_local_stale', - algo_id='__default__', - items=[AddFileItem(file_path=file_path, doc_id='local-stale-doc')], - )) - self._wait_task(uploaded[0]['task_id'], {'SUCCESS'}) - first_task_id = self.manager.reparse(ReparseRequest( - kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], - ))[0] - second_task_id = self.manager.reparse(ReparseRequest( - kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], - ))[0] - stale_resp = self.manager.on_task_callback(TaskCallbackRequest( - callback_id='local-stale-callback', - task_id=first_task_id, - event_type=CallbackEventType.FINISH, - status=DocStatus.SUCCESS, - )) - assert stale_resp['ignored_reason'] == 'stale_task_callback' - self.manager.delete(DeleteRequest(kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'])) - with pytest.raises(DocServiceError) as exc: - self.manager.reparse(ReparseRequest( - kb_id='kb_local_stale', algo_id='__default__', doc_ids=['local-stale-doc'], - )) - assert exc.value.biz_code == 'E_STATE_CONFLICT' - assert second_task_id != first_task_id - - def test_manager_missing_endpoint_surrogates(self): - self.manager.create_kb('kb_local_info', algo_id='__default__') - file_path = self._make_file('local_info.txt', 'local info content') - uploaded = self.manager.upload(UploadRequest( - kb_id='kb_local_info', - algo_id='__default__', - items=[AddFileItem(file_path=file_path, doc_id='local-info-doc')], - )) - algorithms = self.manager.list_algorithms_compat() - assert len(algorithms['items']) >= 1 - algo_info = self.manager.get_algorithm_info('__default__') - assert algo_info['algo_id'] == '__default__' - chunks = self.manager.list_chunks() - assert chunks['items'] == [] - tasks_batch = self.manager.get_tasks_batch([uploaded[0]['task_id']]) - assert len(tasks_batch['items']) == 1 - - def test_delete_kbs_empty_list_rejected(self): - with pytest.raises(DocServiceError) as exc: - self.manager.delete_kbs([]) - assert exc.value.biz_code == 'E_INVALID_PARAM' - - def test_manager_rejects_unknown_kb_algorithm(self): - with pytest.raises(DocServiceError) as exc: - self.manager.create_kb('kb_local_unknown_algo', algo_id='missing_algo') - assert exc.value.biz_code == 'E_INVALID_PARAM' - - def test_manager_update_kb_can_clear_nullable_fields(self): - self.manager.create_kb( - 'kb_local_clearable', - display_name='Clearable', - description='to be cleared', - owner_id='owner-x', - meta={'tag': 'x'}, - algo_id='__default__', - ) - updated = self.manager.update_kb( - 'kb_local_clearable', - display_name=None, - description=None, - owner_id=None, - meta=None, - explicit_fields={'display_name', 'description', 'owner_id', 'meta'}, - ) - assert updated['display_name'] is None - assert updated['description'] is None - assert updated['owner_id'] is None - assert updated['meta'] == {} - - def test_kb_update_idempotency_payload_distinguishes_omitted_and_null(self): - keep_req = KbUpdateRequest(display_name='Renamed', idempotency_key='kb-update-idem') - clear_req = KbUpdateRequest(display_name='Renamed', owner_id=None, idempotency_key='kb-update-idem') - - keep_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', keep_req) - clear_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', clear_req) - - assert keep_payload != clear_payload - - self.manager.run_idempotent( - '/v1/kbs/kb_local_idem:patch', - 
'kb-update-idem', - keep_payload, - lambda: {'kb_id': 'kb_local_idem', 'owner_id': 'kept'}, - ) - with pytest.raises(DocServiceError) as exc: - self.manager.run_idempotent( - '/v1/kbs/kb_local_idem:patch', - 'kb-update-idem', - clear_payload, - lambda: {'kb_id': 'kb_local_idem', 'owner_id': None}, - ) - assert exc.value.biz_code == 'E_IDEMPOTENCY_CONFLICT' - - def test_manager_callback_payload_fallback_and_delete_transition(self): - self.manager.create_kb('kb_local_callback', algo_id='__default__') - file_path = self._make_file('local_callback.txt', 'local callback content') - self.manager._upsert_doc( - doc_id='local-callback-doc', - filename='local_callback.txt', - path=file_path, - metadata={'case': 'callback'}, - source_type=SourceType.EXTERNAL, - ) - self.manager._ensure_kb_document('kb_local_callback', 'local-callback-doc') - queued_at = self.manager._upsert_parse_snapshot( - doc_id='local-callback-doc', - kb_id='kb_local_callback', - algo_id='__default__', - status=DocStatus.DELETING, - task_type=TaskType.DOC_DELETE, - current_task_id='local-delete-task', - queued_at=datetime.now(), - )['queued_at'] - - start_resp = self.manager.on_task_callback(TaskCallbackRequest( - callback_id='local-delete-start', - task_id='local-delete-task', - event_type=CallbackEventType.START, - status=DocStatus.WORKING, - payload={ - 'task_type': TaskType.DOC_DELETE.value, - 'doc_id': 'local-callback-doc', - 'kb_id': 'kb_local_callback', - 'algo_id': '__default__', - }, - )) - assert start_resp['ack'] is True - start_snapshot = self.manager._get_parse_snapshot('local-callback-doc', 'kb_local_callback', '__default__') - assert start_snapshot['status'] == DocStatus.DELETING.value - assert start_snapshot['queued_at'] == queued_at - - finish_resp = self.manager.on_task_callback(TaskCallbackRequest( - callback_id='local-delete-finish', - task_id='local-delete-task', - event_type=CallbackEventType.FINISH, - status=DocStatus.SUCCESS, - payload={ - 'task_type': TaskType.DOC_DELETE.value, - 'doc_id': 'local-callback-doc', - 'kb_id': 'kb_local_callback', - 'algo_id': '__default__', - }, - )) - assert finish_resp['ack'] is True - - finish_snapshot = self.manager._get_parse_snapshot('local-callback-doc', 'kb_local_callback', '__default__') - assert finish_snapshot['status'] == DocStatus.DELETED.value - assert self.manager._has_kb_document('kb_local_callback', 'local-callback-doc') is False - assert self.manager._get_doc('local-callback-doc')['upload_status'] == DocStatus.DELETED.value - - def test_parser_client_algo_endpoint_fallback(self): - client = _ParserClient(parser_url='http://parser.test') - calls = [] - - def fake_get(path, params=None): - del params - calls.append(path) - if path == '/v1/algo/list': - raise RuntimeError('parser http error: 404 missing route') - if path == '/algo/list': - return { - 'code': 200, - 'msg': 'success', - 'data': [{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}], - } - if path == '/v1/algo/__default__/groups': - raise RuntimeError('parser http error: 404 missing route') - if path == '/algo/__default__/group/info': - return { - 'code': 200, - 'msg': 'success', - 'data': [{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}], - } - raise AssertionError(path) - - client._get = fake_get - algo_resp = client.list_algorithms() - group_resp = client.get_algorithm_groups('__default__') - - assert algo_resp.code == 200 - assert group_resp.code == 200 - assert calls == [ - '/v1/algo/list', - '/algo/list', - '/v1/algo/__default__/groups', - 
'/algo/__default__/group/info', - ] diff --git a/tests/basic_tests/RAG/test_document.py b/tests/basic_tests/RAG/test_document.py index 1d279cdd1..16cb1bdd9 100644 --- a/tests/basic_tests/RAG/test_document.py +++ b/tests/basic_tests/RAG/test_document.py @@ -7,15 +7,16 @@ from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag.global_metadata import RAG_DOC_PATH, RAG_DOC_ID from lazyllm.tools.rag import Document, Retriever, TransformArgs, AdaptiveTransform -import lazyllm.tools.rag.document as document_module -from lazyllm.tools.rag.utils import gen_docid -from lazyllm.launcher import cleanup -from unittest.mock import MagicMock, patch -import unittest import os import shutil -import time import tempfile +import time +import unittest +from unittest.mock import MagicMock + +import lazyllm.tools.rag.document as document_module +from lazyllm.launcher import cleanup +from lazyllm.tools.rag.utils import gen_docid class TestDocImpl(unittest.TestCase): @@ -225,171 +226,6 @@ def test_dataset_path_enables_monitoring_by_default_without_manager(self): assert doc._manager._enable_path_monitoring is True assert doc._impl._dataset_path == doc._manager._origin_path - def test_manager_true_disables_monitoring_and_creates_ui(self): - dataset_path = self._build_dataset() - calls = {} - - class FakeDocumentProcessor: - def __init__(self, *args, **kwargs): - self.url = 'http://127.0.0.1:19001/generate' - - def start(self): - calls['processor_started'] = True - - class FakeDocServer: - def __init__(self, *args, **kwargs): - calls['storage_dir'] = kwargs.get('storage_dir') - self._url = 'http://127.0.0.1:19002/generate' - - class FakeDocWebModule: - def __init__(self, doc_server, *args, **kwargs): - calls['web_doc_server'] = doc_server - self.url = 'http://127.0.0.1:19003' - - def stop(self): - return None - - with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ - patch('lazyllm.tools.rag.document.DocServer', FakeDocServer), \ - patch('lazyllm.tools.rag.document.DocWebModule', FakeDocWebModule): - doc = Document(dataset_path, manager=True, create_ui=True) - try: - assert calls['processor_started'] is True - assert calls['storage_dir'] == doc._manager._origin_path - assert calls['web_doc_server'] is doc._manager._manager - assert doc._manager._enable_path_monitoring is False - assert doc._impl._dataset_path == doc._manager._dataset_path - finally: - doc.stop() - - def test_doc_server_manager_disables_monitoring_without_local_path_following(self): - dataset_path = self._build_dataset() - calls = {} - - class FakeDocumentProcessor: - def __init__(self, *args, **kwargs): - calls['parser_url'] = kwargs.get('url') - - def start(self): - return None - - class FakeDocWebModule: - def __init__(self, doc_server, *args, **kwargs): - calls['web_doc_server'] = doc_server - self.url = 'http://127.0.0.1:19003' - - def stop(self): - return None - - class FakeDocServer: - def __init__(self): - self.parser_url = 'http://127.0.0.1:19011/generate' - self._url = 'http://127.0.0.1:19012/generate' - - external_doc_server = FakeDocServer() - - with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ - patch('lazyllm.tools.rag.document.DocServer', FakeDocServer), \ - patch('lazyllm.tools.rag.document.DocWebModule', FakeDocWebModule): - doc = Document(dataset_path, manager=external_doc_server, create_ui=True) - try: - assert calls['parser_url'] == external_doc_server.parser_url - assert calls['web_doc_server'] is external_doc_server - assert 
doc._manager._enable_path_monitoring is False - assert doc._manager._dataset_path == doc._manager._origin_path - assert doc._impl._dataset_path is None - finally: - doc.stop() - - def test_document_processor_manager_requires_store_conf_and_disables_monitoring(self): - class FakeDocumentProcessor: - def __init__(self): - self.start_calls = 0 - - def start(self): - self.start_calls += 1 - - dataset_path = self._build_dataset() - processor = FakeDocumentProcessor() - - with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor): - with self.assertRaises(ValueError): - Document(dataset_path, manager=processor) - assert processor.start_calls == 0 - - with self.assertRaises(ValueError): - Document(dataset_path, manager=processor, store_conf={'type': 'map'}) - assert processor.start_calls == 0 - - doc = Document(dataset_path, manager=processor, store_conf={'type': 'milvus'}) - - assert processor.start_calls == 1 - assert doc._manager._enable_path_monitoring is False - assert doc._impl._dataset_path == doc._manager._origin_path - - def test_managed_document_keeps_local_dataset_path_for_helper_apis(self): - dataset_path = self._build_dataset() - - class FakeDocumentProcessor: - def __init__(self, *args, **kwargs): - self.url = 'http://127.0.0.1:19001/generate' - - def start(self): - return None - - class FakeDocServer: - def __init__(self, *args, **kwargs): - self._url = 'http://127.0.0.1:19002/generate' - - class FakeGraphRagServerModule: - def __init__(self, kg_dir): - self.kg_dir = kg_dir - - def stop(self): - return None - - expected_files = [os.path.join(dataset_path, 'rag.txt')] - - with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ - patch('lazyllm.tools.rag.document.DocServer', FakeDocServer), \ - patch('lazyllm.tools.rag.document.extract_db_schema_from_files', return_value=[]) as extract_mock, \ - patch('lazyllm.tools.rag.graph_document.GraphRagServerModule', FakeGraphRagServerModule): - doc = Document(dataset_path, manager=True) - try: - from lazyllm.tools.rag.graph_document import GraphDocument - - graph_doc = GraphDocument(doc) - assert doc._manager._dataset_path == doc._manager._origin_path - assert doc._impl._dataset_path is None - doc.extract_db_schema(MagicMock()) - extract_mock.assert_called_once_with(expected_files, unittest.mock.ANY) - assert graph_doc._kg_dir == os.path.join(dataset_path, '.graphrag_kg') - finally: - doc.stop() - - def test_remote_doc_server_manager_allows_missing_parser_url(self): - dataset_path = self._build_dataset() - - class FakeDocServer: - def __init__(self, *args, **kwargs): - self._raw_impl = None - self._url = 'http://127.0.0.1:19002/generate' - - @property - def parser_url(self): - return None - - with patch('lazyllm.tools.rag.document.DocServer', FakeDocServer), \ - patch('lazyllm.tools.rag.document.DocumentProcessor') as processor_cls: - doc_server = FakeDocServer() - doc = Document(dataset_path, manager=doc_server) - try: - processor_cls.assert_not_called() - assert doc._manager._dataset_path == doc._manager._origin_path - assert doc._impl._dataset_path is None - finally: - doc.stop() - def test_register_with_pattern(self): Document.create_node_group('AdaptiveChunk1', transform=[ TransformArgs(f=SentenceSplitter, pattern='*.txt', kwargs=dict(chunk_size=512, chunk_overlap=50)), @@ -480,6 +316,7 @@ def test_doc_web_module(self): dataset_path = self._build_dataset() doc = Document(dataset_path, manager=True, create_ui=True) try: + doc.start() doc.create_kb_group(name='test_group') doc2 = 
Document(dataset_path, manager=doc.manager, name='test_group2') assert hasattr(doc._manager, '_docweb') @@ -492,105 +329,16 @@ def test_manager_ui_remains_compatible(self): dataset_path = self._build_dataset() doc = Document(dataset_path, manager='ui') try: + doc.start() assert hasattr(doc._manager, '_docweb') assert doc._manager._enable_path_monitoring is False finally: doc.stop() - def test_doc_web_module_uses_workspace_pythonpath(self): - dataset_path = self._build_dataset() - calls = {} - - class FakeDocumentProcessor: - def __init__(self, *args, **kwargs): - calls['processor_pythonpath'] = kwargs.get('pythonpath') - self.url = 'http://127.0.0.1:19001/generate' - - def start(self): - calls['processor_started'] = True - - class FakeDocServer: - def __init__(self, *args, **kwargs): - calls['doc_server_pythonpath'] = kwargs.get('pythonpath') - calls['parser_url'] = kwargs.get('parser_url') - self._url = 'http://127.0.0.1:19002/generate' - - with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ - patch('lazyllm.tools.rag.document.DocServer', FakeDocServer): - doc = Document(dataset_path, manager=True, create_ui=True) - try: - assert calls['processor_started'] is True - assert calls['processor_pythonpath'] == document_module._LOCAL_PYTHONPATH - assert calls['doc_server_pythonpath'] == document_module._LOCAL_PYTHONPATH - assert calls['parser_url'] == 'http://127.0.0.1:19001/generate' - finally: - doc.stop() - - def test_doc_web_module_registers_algorithms_with_spawned_processor(self): - dataset_path = self._build_dataset() - calls = {'registered_algorithms': [], 'add_doc_calls': []} - - class FakeDocumentProcessor: - def __init__(self, *args, **kwargs): - self.url = 'http://127.0.0.1:19001/generate' - - def start(self): - return None - - def register_algorithm(self, name, *args, **kwargs): - calls['registered_algorithms'].append(name) - - def add_doc(self, input_files, ids, metadatas=None, **kwargs): - calls['add_doc_calls'].append({ - 'input_files': input_files, - 'ids': ids, - 'metadatas': metadatas, - }) - - class FakeDocServer: - def __init__(self, *args, **kwargs): - self._url = 'http://127.0.0.1:19002/generate' - - with patch('lazyllm.tools.rag.document.DocumentProcessor', FakeDocumentProcessor), \ - patch('lazyllm.tools.rag.document.DocServer', FakeDocServer): - doc = Document(dataset_path, manager=True, create_ui=True) - try: - doc._impl._lazy_init() - doc2 = doc.create_kb_group(name='test_group') - doc2._impl._lazy_init() - - assert calls['registered_algorithms'] == ['__default__', 'test_group'] - assert len(calls['add_doc_calls']) == 0 - finally: - doc.stop() - def test_create_ui_requires_doc_server(self): with self.assertRaisesRegex(ValueError, 'requires an available DocServer'): Document(self._build_dataset(), create_ui=True) - def test_remote_doc_server_manager_disables_local_path_follow(self): - dataset_path = self._build_dataset() - - class FakeDocServer: - def __init__(self, *args, **kwargs): - self._raw_impl = None - self._url = 'http://127.0.0.1:19002/generate' - - @property - def parser_url(self): - return None - - with patch('lazyllm.tools.rag.document.DocServer', FakeDocServer): - doc_server = FakeDocServer() - doc = Document(dataset_path, manager=doc_server) - try: - assert doc._manager._enable_path_monitoring is False - assert doc._manager._dataset_path == doc._manager._origin_path - assert doc._impl._enable_path_monitoring is False - assert doc._impl._dataset_path is None - finally: - doc.stop() - def 
test_document_processor_manager_constraints(self): dataset_path = self._build_dataset() processor = document_module.DocumentProcessor(url='http://127.0.0.1:9966') From 4c5ff62f1132e33d0db90a13787740ccdaf813bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Wed, 8 Apr 2026 16:53:30 +0800 Subject: [PATCH 44/46] Fix doc UI and cancel status regressions --- lazyllm/tools/rag/doc_service/doc_manager.py | 2 ++ lazyllm/tools/rag/web.py | 11 +++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index dbf02447c..a284464e7 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -834,6 +834,8 @@ def _enqueue_task( def _apply_doc_upload_status(self, doc_id: str, task_type: TaskType, status: DocStatus): if task_type == TaskType.DOC_ADD: + if status in (DocStatus.WORKING, DocStatus.FAILED, DocStatus.CANCELED, DocStatus.SUCCESS): + self._set_doc_upload_status(doc_id, status) return if task_type == TaskType.DOC_DELETE: if status == DocStatus.DELETING: diff --git a/lazyllm/tools/rag/web.py b/lazyllm/tools/rag/web.py index 102828f03..addba74ad 100644 --- a/lazyllm/tools/rag/web.py +++ b/lazyllm/tools/rag/web.py @@ -51,10 +51,13 @@ def delete_group(self, group_name: str): return response.json()['msg'] def list_groups(self): - response = requests.get( - f'{self.base_url}/list_kb_groups', headers=self.basic_headers(False) - ) - return response.json()['data'] + response = requests.get(f'{self.base_url}/list_kb_groups', headers=self.basic_headers(False)) + payload = response.json() + if 'data' in payload: + return payload['data'] + response = requests.get(f'{self.base_url}/v1/kbs', headers=self.basic_headers(False)) + payload = response.json() + return [item.get('kb_id') for item in payload.get('data', {}).get('items', []) if item.get('kb_id')] def upload_files(self, group_name: str, override: bool = True): response = requests.post( From 1f074ba71c0050ec11a3498278186305810d1352 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Thu, 9 Apr 2026 15:00:28 +0800 Subject: [PATCH 45/46] fix: handle transfer tasks and fail on partial rag writes --- lazyllm/tools/rag/parsing_service/impl.py | 19 +++++ lazyllm/tools/rag/parsing_service/server.py | 27 ++---- lazyllm/tools/rag/store/document_store.py | 9 +- .../RAG/test_doc_processor_transfer.py | 31 +++++++ .../RAG/test_document_store_upsert_failure.py | 40 +++++++++ .../basic_tests/RAG/test_processor_cleanup.py | 83 +++++++++++++++++++ 6 files changed, 187 insertions(+), 22 deletions(-) create mode 100644 tests/basic_tests/RAG/test_doc_processor_transfer.py create mode 100644 tests/basic_tests/RAG/test_document_store_upsert_failure.py create mode 100644 tests/basic_tests/RAG/test_processor_cleanup.py diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index 6c52deae8..d80a6a73f 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -109,6 +109,7 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no transfer_mode: Optional[str] = None, target_kb_id: Optional[str] = None, target_doc_ids: Optional[List[str]] = None, preloaded_root_nodes: Optional[Dict[str, List[DocNode]]] = None): + ids = ids or [] try: if not input_files: return add_start = time.time() @@ -180,9 +181,27 @@ def add_doc(self, input_files: List[str], ids: 
Optional[List[str]] = None, # no LOG.info(f'[_Processor - add_doc] Add documents done! files:{input_files}, ' f'Total Time: {add_time}s, Data Loading Time: {load_time}s') except Exception as e: + cleanup_doc_ids = ids if transfer_mode is None else (target_doc_ids or []) + cleanup_kb_id = kb_id if transfer_mode is None else target_kb_id + self._cleanup_failed_add(cleanup_doc_ids, cleanup_kb_id, clear_schema=transfer_mode is None) LOG.error(f'Add documents failed: {e}, {traceback.format_exc()}') raise e + def _cleanup_failed_add(self, doc_ids: List[str], kb_id: Optional[str], clear_schema: bool) -> None: + if not doc_ids: + return + try: + self._store.remove_nodes(doc_ids=doc_ids, kb_id=kb_id) + except Exception as cleanup_exc: + LOG.error(f'Failed to cleanup nodes for docs {doc_ids} in kb {kb_id}: {cleanup_exc}, ' + f'{traceback.format_exc()}') + if clear_schema and self._schema_extractor: + try: + self._schema_extractor._delete_extract_data(algo_id=self._algo_id, kb_id=kb_id, doc_ids=doc_ids) + except Exception as cleanup_exc: + LOG.error(f'Failed to cleanup schema data for docs {doc_ids} in kb {kb_id}: {cleanup_exc}, ' + f'{traceback.format_exc()}') + def close(self): self._thread_pool.shutdown(wait=True) self._thread_pool = None diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index df2b7cb35..f01ad5317 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -19,7 +19,7 @@ from .base import ( ALGORITHM_TABLE_INFO, WAITING_TASK_QUEUE_TABLE_INFO, FINISHED_TASK_QUEUE_TABLE_INFO, TaskStatus, TaskType, UpdateMetaRequest, AddDocRequest, CancelTaskRequest, DeleteDocRequest, - _calculate_task_score + _calculate_task_score, _resolve_add_doc_task_type ) from .worker import DocumentProcessorWorker as Worker from .queue import _SQLBasedQueue as Queue @@ -561,24 +561,11 @@ def list_doc_chunks( return BaseResponse(code=200, msg='success', data=data) @staticmethod - def _resolve_add_task_type(file_infos) -> str: - has_reparse = False - has_new_file = False - for file_info in file_infos: - if file_info.reparse_group is not None: - has_reparse = True - else: - has_new_file = True - if has_new_file and has_reparse: - raise fastapi.HTTPException( - status_code=400, - detail='new_file_ids and reparse_file_ids cannot be specified at the same time' - ) - if has_reparse: - return TaskType.DOC_REPARSE.value - if has_new_file: - return TaskType.DOC_ADD.value - raise fastapi.HTTPException(status_code=400, detail='no input files or reparse group specified') + def _resolve_add_task_type(request: AddDocRequest) -> str: + try: + return _resolve_add_doc_task_type(request) + except ValueError as exc: + raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc @app.post('/doc/add') def add_doc(self, request: AddDocRequest): # noqa: C901 @@ -598,7 +585,7 @@ def add_doc(self, request: AddDocRequest): # noqa: C901 for file_info in file_infos: if self._path_prefix: file_info.file_path = create_file_path(path=file_info.file_path, prefix=self._path_prefix) - task_type = self._resolve_add_task_type(file_infos) + task_type = self._resolve_add_task_type(request) payload = request.model_dump() resolved_callback_url = self._resolve_callback_url(payload) if resolved_callback_url: diff --git a/lazyllm/tools/rag/store/document_store.py b/lazyllm/tools/rag/store/document_store.py index e6460e0a6..b84b89102 100644 --- a/lazyllm/tools/rag/store/document_store.py +++ b/lazyllm/tools/rag/store/document_store.py @@ 
-158,6 +158,11 @@ def is_group_active(self, group: str) -> bool: def is_group_empty(self, group: str) -> bool: return not self.impl.get(self._gen_collection_name(group), {}, limit=10) + def _upsert_segments(self, group: str, segments: List[dict]) -> None: + collection_name = self._gen_collection_name(group) + if not self.impl.upsert(collection_name, segments): + raise RuntimeError(f'[_DocumentStore - {self._algo_name}] Failed to upsert segments for group {group}') + def update_nodes(self, nodes: List[DocNode], copy: bool = False): # noqa: C901 if not nodes: return @@ -176,7 +181,7 @@ def update_nodes(self, nodes: List[DocNode], copy: bool = False): # noqa: C901 LOG.warning(f'[_DocumentStore - {self._algo_name}] Group {group} is not active, skip') continue for i in range(0, len(segments), INSERT_BATCH_SIZE): - self.impl.upsert(self._gen_collection_name(group), segments[i:i + INSERT_BATCH_SIZE]) + self._upsert_segments(group, segments[i:i + INSERT_BATCH_SIZE]) # update indices for index in self._indices.values(): index.update(nodes) @@ -288,7 +293,7 @@ def update_doc_meta(self, doc_id: str, metadata: dict, kb_id: str = None) -> Non segment['global_meta'].update(metadata) group_segments[segment.get('group')].append(segment) for group, segments in group_segments.items(): - self.impl.upsert(self._gen_collection_name(group), segments) + self._upsert_segments(group, segments) LOG.info(f'[_DocumentStore] Updated metadata for doc_id: {doc_id} in dataset: {kb_id}') return diff --git a/tests/basic_tests/RAG/test_doc_processor_transfer.py b/tests/basic_tests/RAG/test_doc_processor_transfer.py new file mode 100644 index 000000000..73adbed3c --- /dev/null +++ b/tests/basic_tests/RAG/test_doc_processor_transfer.py @@ -0,0 +1,31 @@ +from lazyllm.tools.rag.parsing_service import DocumentProcessor +from lazyllm.tools.rag.parsing_service.base import AddDocRequest, FileInfo, TaskType, TransferParams + + +def test_resolve_add_task_type_returns_transfer_for_copy_request(): + request = AddDocRequest( + algo_id='general_algo', + kb_id='kb_source', + file_infos=[FileInfo( + file_path='/tmp/source.pdf', + doc_id='doc_source', + transfer_params=TransferParams( + mode='cp', + target_algo_id='general_algo', + target_doc_id='doc_target', + target_kb_id='kb_target', + ), + )], + ) + + assert DocumentProcessor._Impl._resolve_add_task_type(request) == TaskType.DOC_TRANSFER.value + + +def test_resolve_add_task_type_keeps_add_for_regular_upload(): + request = AddDocRequest( + algo_id='general_algo', + kb_id='kb_source', + file_infos=[FileInfo(file_path='/tmp/source.pdf', doc_id='doc_source')], + ) + + assert DocumentProcessor._Impl._resolve_add_task_type(request) == TaskType.DOC_ADD.value diff --git a/tests/basic_tests/RAG/test_document_store_upsert_failure.py b/tests/basic_tests/RAG/test_document_store_upsert_failure.py new file mode 100644 index 000000000..369ce4aed --- /dev/null +++ b/tests/basic_tests/RAG/test_document_store_upsert_failure.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock + +import pytest + +from lazyllm.tools.rag.doc_node import DocNode +from lazyllm.tools.rag.global_metadata import RAG_DOC_ID, RAG_KB_ID +from lazyllm.tools.rag.store.document_store import _DocumentStore + + +def test_update_nodes_raises_when_store_upsert_returns_false(): + document_store = _DocumentStore(algo_name='test_algo', store={'type': 'map'}) + document_store.activate_group('group1') + document_store.impl.upsert = MagicMock(return_value=False) + + node = DocNode( + uid='node1', + text='text1', + group='group1', + 
global_metadata={RAG_KB_ID: 'kb1', RAG_DOC_ID: 'doc1'}, + ) + + with pytest.raises(RuntimeError, match='Failed to upsert segments for group group1'): + document_store.update_nodes([node], copy=True) + + +def test_update_doc_meta_raises_when_store_upsert_returns_false(): + document_store = _DocumentStore(algo_name='test_algo', store={'type': 'map'}) + document_store.activate_group('group1') + + node = DocNode( + uid='node1', + text='text1', + group='group1', + global_metadata={RAG_KB_ID: 'kb1', RAG_DOC_ID: 'doc1'}, + ) + document_store.update_nodes([node], copy=True) + document_store.impl.upsert = MagicMock(return_value=False) + + with pytest.raises(RuntimeError, match='Failed to upsert segments for group group1'): + document_store.update_doc_meta(doc_id='doc1', metadata={'foo': 'bar'}, kb_id='kb1') diff --git a/tests/basic_tests/RAG/test_processor_cleanup.py b/tests/basic_tests/RAG/test_processor_cleanup.py new file mode 100644 index 000000000..e42d4fba8 --- /dev/null +++ b/tests/basic_tests/RAG/test_processor_cleanup.py @@ -0,0 +1,83 @@ +from unittest.mock import MagicMock + +import pytest + +from lazyllm.tools.rag.doc_node import DocNode +from lazyllm.tools.rag.global_metadata import RAG_DOC_ID, RAG_KB_ID +from lazyllm.tools.rag.parsing_service import _Processor +from lazyllm.tools.rag.store import LAZY_IMAGE_GROUP, LAZY_ROOT_NAME + + +def _make_root_node(doc_id: str, kb_id: str) -> DocNode: + return DocNode( + uid=f'{doc_id}-root', + text='root', + group=LAZY_ROOT_NAME, + global_metadata={RAG_DOC_ID: doc_id, RAG_KB_ID: kb_id}, + ) + + +def test_add_doc_cleans_partial_segments_and_schema_on_failure(): + store = MagicMock() + reader = MagicMock() + schema_extractor = MagicMock() + reader.load_data.return_value = { + LAZY_ROOT_NAME: [_make_root_node('doc1', 'kb1')], + LAZY_IMAGE_GROUP: [], + } + store.update_nodes.side_effect = RuntimeError('upsert failed') + + processor = _Processor( + algo_id='algo1', + store=store, + reader=reader, + node_groups={}, + schema_extractor=schema_extractor, + ) + try: + with pytest.raises(RuntimeError, match='upsert failed'): + processor.add_doc( + input_files=['/tmp/doc1.txt'], + ids=['doc1'], + metadatas=[{}], + kb_id='kb1', + ) + finally: + processor.close() + + store.remove_nodes.assert_called_once_with(doc_ids=['doc1'], kb_id='kb1') + schema_extractor._delete_extract_data.assert_called_once_with( + algo_id='algo1', + kb_id='kb1', + doc_ids=['doc1'], + ) + + +def test_transfer_failure_cleans_target_segments_only(): + store = MagicMock() + reader = MagicMock() + store.get_nodes.return_value = [_make_root_node('source-doc', 'source-kb')] + store.update_nodes.side_effect = RuntimeError('upsert failed') + + processor = _Processor( + algo_id='algo1', + store=store, + reader=reader, + node_groups={}, + ) + try: + with pytest.raises(RuntimeError, match='upsert failed'): + processor.add_doc( + input_files=['/tmp/source.txt'], + ids=['source-doc'], + metadatas=[{}], + kb_id='source-kb', + transfer_mode='cp', + target_kb_id='target-kb', + target_doc_ids=['target-doc'], + ) + finally: + processor.close() + + store.remove_nodes.assert_called_once_with(doc_ids=['target-doc'], kb_id='target-kb') + reader.load_data.assert_not_called() From 8993fb64ada46e9bc6f784d41424de012bc78fce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=99=88=E5=AE=B6=E8=B1=AA?= Date: Fri, 10 Apr 2026 12:50:33 +0800 Subject: [PATCH 46/46] fix(rag): address doc_service review feedback --- examples/rag/doc_service_standalone.py | 9 +- lazyllm/common/utils.py | 8 +- 
lazyllm/components/deploy/relay/server.py | 8 +- lazyllm/tools/rag/__init__.py | 3 +- lazyllm/tools/rag/doc_service/__init__.py | 53 ++- lazyllm/tools/rag/doc_service/base.py | 23 +- lazyllm/tools/rag/doc_service/doc_manager.py | 313 ++++++++++++------ .../RAG/test_doc_service_doc_manager.py | 77 +++++ 8 files changed, 375 insertions(+), 119 deletions(-) diff --git a/examples/rag/doc_service_standalone.py b/examples/rag/doc_service_standalone.py index 8d98d1f24..1913baaf9 100644 --- a/examples/rag/doc_service_standalone.py +++ b/examples/rag/doc_service_standalone.py @@ -11,6 +11,7 @@ import argparse import os import time +from pathlib import Path from typing import Any, Dict import requests @@ -59,8 +60,8 @@ def _wait_http_ok(url: str, timeout: float = 20.0): def _build_store_conf(root_dir: str) -> Dict[str, Any]: segment_store_path = os.path.join(root_dir, 'segments.db') milvus_store_path = os.path.join(root_dir, 'milvus_lite.db') - open(segment_store_path, 'a', encoding='utf-8').close() - open(milvus_store_path, 'a', encoding='utf-8').close() + Path(segment_store_path).touch() + Path(milvus_store_path).touch() return { 'segment_store': {'type': 'map', 'kwargs': {'uri': segment_store_path}}, 'vector_store': { @@ -149,6 +150,7 @@ def main(): if parser_url: parser_url = parser_url.rstrip('/') _wait_http_ok(f'{parser_url}/health') + _wait_algo_ready(parser_url, args.algo_id) else: parser_server, parser_url = _start_local_parser(args.parser_port, paths['parser_db']) document = _register_demo_algorithm(parser_url, args.algo_id, paths['store_dir']) @@ -169,8 +171,7 @@ def main(): print(f'Parser URL: {parser_url}', flush=True) print(f'Storage Dir: {paths["storage_dir"]}', flush=True) print(f'Doc DB: {paths["doc_db"]}', flush=True) - if document: - print(f'Algorithm ID: {args.algo_id}', flush=True) + print(f'Algorithm ID: {args.algo_id}', flush=True) if args.wait: # Step 4: keep the services alive for manual API testing. 
diff --git a/lazyllm/common/utils.py b/lazyllm/common/utils.py index 68745f06e..3dfa5e1f0 100644 --- a/lazyllm/common/utils.py +++ b/lazyllm/common/utils.py @@ -185,12 +185,14 @@ def _collect_test_modules(obj): def env_helper(): os.environ['LAZYLLM_ON_CLOUDPICKLE'] = 'ON' modules = _collect_test_modules(f) - for module in modules: - cloudpickle.register_pickle_by_value(module) + registered_modules = [] try: + for module in modules: + cloudpickle.register_pickle_by_value(module) + registered_modules.append(module) yield finally: - for module in modules: + for module in reversed(registered_modules): cloudpickle.unregister_pickle_by_value(module) os.environ['LAZYLLM_ON_CLOUDPICKLE'] = 'OFF' diff --git a/lazyllm/components/deploy/relay/server.py b/lazyllm/components/deploy/relay/server.py index 51e713562..2841e5172 100644 --- a/lazyllm/components/deploy/relay/server.py +++ b/lazyllm/components/deploy/relay/server.py @@ -104,9 +104,9 @@ async def wrapper(request: fastapi.Request): @security_check async def lazyllm_call(request: fastapi.Request): try: - fname, args, kwargs = await request.json() - args, kwargs = load_obj(args), load_obj(kwargs) - r = await async_wrapper(getattr(func, fname), *args, **kwargs) + fname, call_args, call_kwargs = await request.json() + call_args, call_kwargs = load_obj(call_args), load_obj(call_kwargs) + r = await async_wrapper(getattr(func, fname), *call_args, **call_kwargs) return fastapi.responses.Response(content=codecs.encode(pickle.dumps(r), 'base64')) except requests.RequestException as e: return fastapi.responses.Response(content=f'{str(e)}', status_code=500) @@ -156,7 +156,7 @@ def impl(o): def generate_stream(): for o in output: yield impl(o) - return fastapi.responses.StreamingResponse(generate_stream(), media_type='text_plain') + return fastapi.responses.StreamingResponse(generate_stream(), media_type='text/plain') elif args.after_function: assert (callable(after_func)), 'after_func must be callable' r = inspect.getfullargspec(after_func) diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index 03f3039df..ccf80bf9a 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -4,7 +4,7 @@ # flake8: noqa: E402 from .document import Document -from .doc_manager import DocManager +from .doc_manager import DocManager, DocListManager from .doc_service import DocServer from .graph_document import GraphDocument, UrlGraphDocument from .retriever import Retriever, TempDocRetriever, ContextRetriever, WeightedRetriever, PriorityRetriever @@ -32,6 +32,7 @@ 'add_post_action_for_default_reader', 'Document', 'DocManager', + 'DocListManager', 'DocServer', 'GraphDocument', 'UrlGraphDocument', diff --git a/lazyllm/tools/rag/doc_service/__init__.py b/lazyllm/tools/rag/doc_service/__init__.py index 9f3ef78e9..46d8f0cca 100644 --- a/lazyllm/tools/rag/doc_service/__init__.py +++ b/lazyllm/tools/rag/doc_service/__init__.py @@ -1,4 +1,55 @@ +from .base import ( + AddFileItem, + AddRequest, + AlgorithmInfoRequest, + CallbackEventType, + DeleteRequest, + DocServiceError, + DocStatus, + KBStatus, + KbBatchQueryRequest, + KbCreateRequest, + KbDeleteBatchRequest, + KbUpdateRequest, + MetadataPatchItem, + MetadataPatchRequest, + ReparseRequest, + SourceType, + TaskBatchRequest, + TaskCallbackRequest, + TaskCancelRequest, + TaskInfoRequest, + TransferItem, + TransferRequest, + UploadRequest, +) from .doc_server import DocServer from .doc_manager import DocManager -__all__ = ['DocServer', 'DocManager'] +__all__ = [ + 'AddFileItem', + 
'AddRequest', + 'AlgorithmInfoRequest', + 'CallbackEventType', + 'DeleteRequest', + 'DocManager', + 'DocServer', + 'DocServiceError', + 'DocStatus', + 'KBStatus', + 'KbBatchQueryRequest', + 'KbCreateRequest', + 'KbDeleteBatchRequest', + 'KbUpdateRequest', + 'MetadataPatchItem', + 'MetadataPatchRequest', + 'ReparseRequest', + 'SourceType', + 'TaskBatchRequest', + 'TaskCallbackRequest', + 'TaskCancelRequest', + 'TaskInfoRequest', + 'TransferItem', + 'TransferRequest', + 'UploadRequest', +] diff --git a/lazyllm/tools/rag/doc_service/base.py b/lazyllm/tools/rag/doc_service/base.py index c3e80fc0e..28395cbf3 100644 --- a/lazyllm/tools/rag/doc_service/base.py +++ b/lazyllm/tools/rag/doc_service/base.py @@ -2,7 +2,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, TypeAlias from uuid import uuid4 from pydantic import AliasChoices, BaseModel, Field, model_validator @@ -43,6 +43,7 @@ class CallbackEventType(str, Enum): 'E_STATE_CONFLICT': 409, 'E_IDEMPOTENCY_CONFLICT': 409, 'E_IDEMPOTENCY_IN_PROGRESS': 409, + 'E_UPSTREAM_ERROR': 502, } @@ -94,8 +95,8 @@ def validate_items(self): return self -AddRequest = DocItemsRequest -UploadRequest = DocItemsRequest +AddRequest: TypeAlias = DocItemsRequest +UploadRequest: TypeAlias = DocItemsRequest class _DocMutationRequest(BaseModel): @@ -240,8 +241,8 @@ class KbRequest(BaseModel): idempotency_key: Optional[str] = None -KbCreateRequest = KbRequest -KbUpdateRequest = KbRequest +KbCreateRequest: TypeAlias = KbRequest +KbUpdateRequest: TypeAlias = KbRequest class KbBatchQueryRequest(BaseModel): @@ -334,6 +335,18 @@ def validate_kb_ids(self): } +DOC_PATH_LOCKS_TABLE_INFO = { + 'name': 'lazyllm_doc_path_locks', + 'comment': 'Transient lock table for serializing document path writes', + 'columns': [ + {'name': 'path', 'data_type': 'string', 'nullable': False, 'is_primary_key': True, + 'comment': 'Absolute file path'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + ], +} + + KBS_TABLE_INFO = { 'name': 'lazyllm_knowledge_bases', 'comment': 'Knowledge base table', diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py index a284464e7..0af64be32 100644 --- a/lazyllm/tools/rag/doc_service/doc_manager.py +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -10,16 +10,15 @@ from sqlalchemy.exc import IntegrityError from lazyllm import LOG -from lazyllm.thirdparty import fastapi from ..utils import BaseResponse, _get_default_db_config, _orm_to_dict from ...sql import SqlManager from .base import ( AddRequest, CALLBACK_RECORDS_TABLE_INFO, CallbackEventType, DOC_SERVICE_TASKS_TABLE_INFO, - DOCUMENTS_TABLE_INFO, DeleteRequest, DocServiceError, DocStatus, IDEMPOTENCY_RECORDS_TABLE_INFO, - KB_ALGORITHM_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KBStatus, MetadataPatchRequest, - PARSE_STATE_TABLE_INFO, ReparseRequest, SourceType, TaskCallbackRequest, TaskType, TransferRequest, - UploadRequest, + DOC_PATH_LOCKS_TABLE_INFO, DOCUMENTS_TABLE_INFO, DeleteRequest, DocServiceError, DocStatus, + IDEMPOTENCY_RECORDS_TABLE_INFO, KB_ALGORITHM_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, + KBStatus, MetadataPatchRequest, PARSE_STATE_TABLE_INFO, ReparseRequest, SourceType, + TaskCallbackRequest, TaskType, TransferRequest, UploadRequest, ) from .parser_client import ParserClient from .utils import ( @@ -43,15 +42,18 @@ def __init__( **self._db_config, 
tables_info_dict={ 'tables': [ - DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, KB_ALGORITHM_TABLE_INFO, - PARSE_STATE_TABLE_INFO, IDEMPOTENCY_RECORDS_TABLE_INFO, CALLBACK_RECORDS_TABLE_INFO, - DOC_SERVICE_TASKS_TABLE_INFO, + DOC_PATH_LOCKS_TABLE_INFO, DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, + KB_ALGORITHM_TABLE_INFO, PARSE_STATE_TABLE_INFO, IDEMPOTENCY_RECORDS_TABLE_INFO, + CALLBACK_RECORDS_TABLE_INFO, DOC_SERVICE_TASKS_TABLE_INFO, ] }, ) self._ensure_indexes() self._parser_client = ParserClient(parser_url=parser_url) - self._parser_client.health() + try: + self._parser_client.health() + except Exception as exc: + raise RuntimeError(f'parser service is unavailable: {parser_url}') from exc self._callback_url = callback_url def set_callback_url(self, callback_url: str): @@ -197,32 +199,49 @@ def _ensure_algorithm_exists(self, algo_id: str): return raise DocServiceError('E_INVALID_PARAM', f'invalid algo_id: {algo_id}', {'algo_id': algo_id}) - def _ensure_kb_document(self, kb_id: str, doc_id: str): + def _refresh_kb_doc_count_in_session(self, session, kb_id: str): + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if kb_row is None: + return + kb_row.doc_count = session.query(Rel).filter(Rel.kb_id == kb_id).count() + if kb_row.status == KBStatus.DELETING.value and kb_row.doc_count == 0: + kb_row.status = KBStatus.DELETED.value + kb_row.updated_at = datetime.now() + session.add(kb_row) + + def _ensure_kb_document_in_session(self, session, kb_id: str, doc_id: str): now = datetime.now() - created = False + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() + if row is None: + session.add(Rel(kb_id=kb_id, doc_id=doc_id, created_at=now, updated_at=now)) + return True + row.updated_at = now + session.add(row) + return False + + def _ensure_kb_document(self, kb_id: str, doc_id: str): with self._db_manager.get_session() as session: - Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) - row = session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() - if row is None: - created = True - row = Rel(kb_id=kb_id, doc_id=doc_id, created_at=now, updated_at=now) - else: - row.updated_at = now - session.add(row) - if created: - self._refresh_kb_doc_count(kb_id) + created = self._ensure_kb_document_in_session(session, kb_id, doc_id) + if created: + self._refresh_kb_doc_count_in_session(session, kb_id) return created + def _remove_kb_document_in_session(self, session, kb_id: str, doc_id: str): + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() + if row is None: + return False + session.delete(row) + return True + def _remove_kb_document(self, kb_id: str, doc_id: str): - removed = False with self._db_manager.get_session() as session: - Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) - row = session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() - if row is not None: - session.delete(row) - removed = True - if removed: - self._refresh_kb_doc_count(kb_id) + removed = self._remove_kb_document_in_session(session, kb_id, doc_id) + if removed: + self._refresh_kb_doc_count_in_session(session, kb_id) return removed def 
_load_idempotency_record(self, endpoint: str, idempotency_key: str): @@ -331,7 +350,7 @@ def _create_task_record(self, task_id: str, task_type: TaskType, doc_id: str, kb now = datetime.now() with self._db_manager.get_session() as session: Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) - session.add(Task( + row = Task( task_id=task_id, task_type=task_type.value, doc_id=doc_id, @@ -345,8 +364,12 @@ def _create_task_record(self, task_id: str, task_type: TaskType, doc_id: str, kb updated_at=now, started_at=None, finished_at=None, - )) - return self._get_task_record(task_id) + ) + session.add(row) + session.flush() + task = _orm_to_dict(row) + task['message'] = from_json(task.get('message')) + return task def _get_task_record(self, task_id: str): with self._db_manager.get_session() as session: @@ -368,20 +391,14 @@ def _update_task_record(self, task_id: str, **fields): setattr(row, key, value) row.updated_at = datetime.now() session.add(row) - return self._get_task_record(task_id) + session.flush() + task = _orm_to_dict(row) + task['message'] = from_json(task.get('message')) + return task def _refresh_kb_doc_count(self, kb_id: str): with self._db_manager.get_session() as session: - Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) - Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) - kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() - if kb_row is None: - return - kb_row.doc_count = session.query(Rel).filter(Rel.kb_id == kb_id).count() - if kb_row.status == KBStatus.DELETING.value and kb_row.doc_count == 0: - kb_row.status = KBStatus.DELETED.value - kb_row.updated_at = datetime.now() - session.add(kb_row) + self._refresh_kb_doc_count_in_session(session, kb_id) def _list_kb_doc_ids(self, kb_id: str) -> List[str]: with self._db_manager.get_session() as session: @@ -411,8 +428,9 @@ def _get_doc_by_path(self, path: str): row = session.query(Doc).filter(Doc.path == path).first() return _orm_to_dict(row) if row else None - def _upsert_doc( + def _upsert_doc_in_session( self, + session, doc_id: str, filename: str, path: str, @@ -426,48 +444,118 @@ def _upsert_doc( size_bytes = os.path.getsize(path) if os.path.exists(path) else None content_hash = sha256_file(path) if os.path.exists(path) else None allowed_path_doc_ids = allowed_path_doc_ids or set() + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + PathLock = self._db_manager.get_table_orm_class(DOC_PATH_LOCKS_TABLE_INFO['name']) + session.add(PathLock(path=path, created_at=now)) + session.flush() + try: + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + path_rows = session.query(Doc).filter(Doc.path == path).all() + conflict = next( + ( + item for item in path_rows + if item.doc_id != doc_id and item.doc_id not in allowed_path_doc_ids + ), + None, + ) + if conflict is not None: + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc path already exists: {path}', + {'doc_id': conflict.doc_id, 'path': path}, + ) + if row is None: + row = Doc( + doc_id=doc_id, + filename=filename, + path=path, + meta=to_json(metadata), + upload_status=upload_status.value, + source_type=source_type.value, + file_type=file_type, + content_hash=content_hash, + size_bytes=size_bytes, + created_at=now, + updated_at=now, + ) + else: + row.filename = filename + row.path = path + row.meta = to_json(metadata) + row.upload_status = upload_status.value + row.source_type = source_type.value + row.file_type = file_type + row.content_hash = content_hash + 
row.size_bytes = size_bytes + row.updated_at = now + session.add(row) + session.flush() + return row + finally: + try: + session.query(PathLock).filter(PathLock.path == path).delete() + session.flush() + except Exception: + pass + + def _upsert_doc( + self, + doc_id: str, + filename: str, + path: str, + metadata: Dict[str, Any], + source_type: SourceType, + upload_status: DocStatus = DocStatus.SUCCESS, + allowed_path_doc_ids: Optional[Set[str]] = None, + ): try: with self._db_manager.get_session() as session: - Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) - row = session.query(Doc).filter(Doc.doc_id == doc_id).first() - path_row = session.query(Doc).filter(Doc.path == path).first() - if path_row is not None and path_row.doc_id != doc_id and path_row.doc_id not in allowed_path_doc_ids: - raise DocServiceError( - 'E_STATE_CONFLICT', - f'doc path already exists: {path}', - {'doc_id': path_row.doc_id, 'path': path}, - ) - if row is None: - row = Doc( - doc_id=doc_id, - filename=filename, - path=path, - meta=to_json(metadata), - upload_status=upload_status.value, - source_type=source_type.value, - file_type=file_type, - content_hash=content_hash, - size_bytes=size_bytes, - created_at=now, - updated_at=now, - ) - else: - row.filename = filename - row.path = path - row.meta = to_json(metadata) - row.upload_status = upload_status.value - row.source_type = source_type.value - row.file_type = file_type - row.content_hash = content_hash - row.size_bytes = size_bytes - row.updated_at = now - session.add(row) + row = self._upsert_doc_in_session( + session, doc_id, filename, path, metadata, source_type, + upload_status=upload_status, allowed_path_doc_ids=allowed_path_doc_ids, + ) + return _orm_to_dict(row) except IntegrityError as exc: existing = self._get_doc_by_path(path) if ( existing is not None and existing.get('doc_id') != doc_id - and existing.get('doc_id') not in allowed_path_doc_ids + and existing.get('doc_id') not in (allowed_path_doc_ids or set()) + ): + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc path already exists: {path}', + {'doc_id': existing['doc_id'], 'path': path}, + ) from exc + raise + + def _upsert_doc_and_bind( + self, + kb_id: str, + doc_id: str, + filename: str, + path: str, + metadata: Dict[str, Any], + source_type: SourceType, + upload_status: DocStatus = DocStatus.SUCCESS, + allowed_path_doc_ids: Optional[Set[str]] = None, + ): + try: + with self._db_manager.get_session() as session: + row = self._upsert_doc_in_session( + session, doc_id, filename, path, metadata, source_type, + upload_status=upload_status, allowed_path_doc_ids=allowed_path_doc_ids, + ) + created = self._ensure_kb_document_in_session(session, kb_id, doc_id) + if created: + self._refresh_kb_doc_count_in_session(session, kb_id) + return _orm_to_dict(row) + except IntegrityError as exc: + existing = self._get_doc_by_path(path) + if ( + existing is not None + and existing.get('doc_id') != doc_id + and existing.get('doc_id') not in (allowed_path_doc_ids or set()) ): raise DocServiceError( 'E_STATE_CONFLICT', @@ -475,7 +563,6 @@ def _upsert_doc( {'doc_id': existing['doc_id'], 'path': path}, ) from exc raise - return self._get_doc(doc_id) def _set_doc_upload_status(self, doc_id: str, status: DocStatus): with self._db_manager.get_session() as session: @@ -514,22 +601,38 @@ def _delete_parse_snapshots(self, doc_id: str, kb_id: str): State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) session.query(State).filter(State.doc_id == doc_id, State.kb_id == 
kb_id).delete() - def _delete_doc_if_orphaned(self, doc_id: str) -> bool: - if self._doc_relation_count(doc_id) > 0: + def _delete_doc_if_orphaned_in_session(self, session, doc_id: str) -> bool: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + session.flush() + if session.query(Rel).filter(Rel.doc_id == doc_id).count() > 0: return False - with self._db_manager.get_session() as session: - Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) - row = session.query(Doc).filter(Doc.doc_id == doc_id).first() - if row is None: - return False - session.delete(row) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if row is None: + return False + session.delete(row) return True + def _delete_doc_if_orphaned(self, doc_id: str) -> bool: + with self._db_manager.get_session() as session: + return self._delete_doc_if_orphaned_in_session(session, doc_id) + + def _remove_kb_document_and_delete_orphan(self, kb_id: str, doc_id: str) -> bool: + with self._db_manager.get_session() as session: + removed = self._remove_kb_document_in_session(session, kb_id, doc_id) + doc_deleted = self._delete_doc_if_orphaned_in_session(session, doc_id) + if removed: + self._refresh_kb_doc_count_in_session(session, kb_id) + return doc_deleted + def _purge_deleted_kb_doc_data(self, kb_id: str, doc_id: str, remove_relation: bool = False): + doc_deleted = False if remove_relation: - self._remove_kb_document(kb_id, doc_id) + doc_deleted = self._remove_kb_document_and_delete_orphan(kb_id, doc_id) + else: + doc_deleted = self._delete_doc_if_orphaned(doc_id) self._delete_parse_snapshots(doc_id, kb_id) - if not self._delete_doc_if_orphaned(doc_id): + if not doc_deleted: self._sync_doc_upload_status(doc_id) def _mark_task_cleanup_policy(self, task_id: str, cleanup_policy: str): @@ -1030,7 +1133,8 @@ def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: doc_id = item['doc_id'] file_path = item['file_path'] metadata = item['metadata'] - doc = self._upsert_doc( + doc = self._upsert_doc_and_bind( + kb_id=request.kb_id, doc_id=doc_id, filename=item['filename'], path=file_path, @@ -1038,7 +1142,6 @@ def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: source_type=source_type, upload_status=DocStatus.SUCCESS, ) - self._ensure_kb_document(request.kb_id, doc_id) try: task_id, snapshot = self._enqueue_task( doc_id, request.kb_id, request.algo_id, TaskType.DOC_ADD, @@ -1163,7 +1266,8 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: for item in prepared_items: task_id = None try: - self._upsert_doc( + self._upsert_doc_and_bind( + kb_id=item['target_kb_id'], doc_id=item['target_doc_id'], filename=item['target_filename'], path=item['target_file_path'], @@ -1172,7 +1276,6 @@ def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: upload_status=DocStatus.SUCCESS, allowed_path_doc_ids={item['doc_id']}, ) - self._ensure_kb_document(item['target_kb_id'], item['target_doc_id']) task_id, snapshot = self._enqueue_task( item['target_doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, idempotency_key=request.idempotency_key, @@ -1390,11 +1493,16 @@ def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 ), ) + upload_status_handled = False if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED: - self._remove_kb_document(kb_id, doc_id) - self._apply_doc_upload_status(doc_id, task_type, 
final_status) + if cleanup_policy == 'purge': + self._purge_deleted_kb_doc_data(kb_id, doc_id, remove_relation=True) + upload_status_handled = True + else: + self._remove_kb_document(kb_id, doc_id) + if not upload_status_handled: + self._apply_doc_upload_status(doc_id, task_type, final_status) if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED and cleanup_policy == 'purge': - self._purge_deleted_kb_doc_data(kb_id, doc_id) self._finalize_kb_deletion_if_empty(kb_id) return {'ack': True, 'deduped': False, 'ignored_reason': None} @@ -1557,15 +1665,15 @@ def cancel_task(self, task_id: str): def list_algorithms(self): resp = self._parser_client.list_algorithms() if resp.code != 200: - raise fastapi.HTTPException(status_code=502, detail=resp.msg) + raise DocServiceError('E_UPSTREAM_ERROR', resp.msg, {'upstream_status': resp.code}) return resp.data def get_algo_groups(self, algo_id: str): resp = self._parser_client.get_algorithm_groups(algo_id) if resp.code == 404: - raise fastapi.HTTPException(status_code=404, detail='algo not found') + raise DocServiceError('E_NOT_FOUND', 'algo not found', {'algo_id': algo_id}) if resp.code != 200: - raise fastapi.HTTPException(status_code=502, detail=resp.msg) + raise DocServiceError('E_UPSTREAM_ERROR', resp.msg, {'algo_id': algo_id, 'upstream_status': resp.code}) return resp.data def list_algorithms_compat(self): @@ -1624,7 +1732,10 @@ def list_chunks( if resp.code == 400: raise DocServiceError('E_INVALID_PARAM', resp.msg, {'kb_id': kb_id, 'doc_id': doc_id, 'group': group}) if resp.code != 200: - raise fastapi.HTTPException(status_code=502, detail=resp.msg) + raise DocServiceError( + 'E_UPSTREAM_ERROR', resp.msg, + {'kb_id': kb_id, 'doc_id': doc_id, 'group': group, 'upstream_status': resp.code} + ) data = dict(resp.data or {}) data['page'] = page data['page_size'] = page_size diff --git a/tests/basic_tests/RAG/test_doc_service_doc_manager.py b/tests/basic_tests/RAG/test_doc_service_doc_manager.py index 9ace9c6be..92b054324 100644 --- a/tests/basic_tests/RAG/test_doc_service_doc_manager.py +++ b/tests/basic_tests/RAG/test_doc_service_doc_manager.py @@ -1,5 +1,6 @@ import os import tempfile +import threading import time from concurrent.futures import ThreadPoolExecutor from datetime import datetime @@ -171,6 +172,42 @@ def handler(): assert replay == result +def test_manager_upsert_same_path_is_serialized(manager_harness): + file_path = manager_harness.make_file('serialized.txt', 'serialized content') + barrier = threading.Barrier(2) + + def worker(doc_id: str): + barrier.wait() + return manager_harness.manager._upsert_doc( + doc_id=doc_id, + filename='serialized.txt', + path=file_path, + metadata={}, + source_type=SourceType.API, + upload_status=DocStatus.SUCCESS, + ) + + with ThreadPoolExecutor(max_workers=2) as pool: + futures = [pool.submit(worker, doc_id) for doc_id in ('serialized-a', 'serialized-b')] + + results = [] + errors = [] + for future in futures: + try: + results.append(future.result(timeout=2)) + except Exception as exc: + errors.append(exc) + + assert len(results) == 1 + assert len(errors) == 1 + assert isinstance(errors[0], DocServiceError) + assert errors[0].biz_code == 'E_STATE_CONFLICT' + + with manager_harness.manager._db_manager.get_session() as session: + Doc = manager_harness.manager._db_manager.get_table_orm_class('lazyllm_documents') + assert session.query(Doc).filter(Doc.path == file_path).count() == 1 + + def test_manager_upload_callback_and_doc_detail(manager_harness): 
manager_harness.manager.create_kb('kb_upload', algo_id='__default__') file_path = manager_harness.make_file('upload.txt', 'upload content') @@ -397,6 +434,46 @@ def test_manager_transfer_move_cleans_source_doc_with_target_doc_id(manager_harn assert manager_harness.manager._get_parse_snapshot('source-doc-move', 'kb_move_source', '__default__') is None +def test_manager_purge_local_and_rebind_keep_doc_consistent(manager_harness): + manager_harness.manager.create_kb('kb_purge_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_purge_target', algo_id='__default__') + file_path = manager_harness.make_file('purge-rebind.txt', 'purge rebind content') + uploaded = manager_harness.manager.upload(UploadRequest( + kb_id='kb_purge_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='shared-doc')], + )) + manager_harness.finish_task(uploaded[0]['task_id']) + + barrier = threading.Barrier(2) + + def purge(): + barrier.wait() + manager_harness.manager._purge_deleted_kb_doc_data('kb_purge_source', 'shared-doc', remove_relation=True) + + def rebind(): + barrier.wait() + return manager_harness.manager.upload(UploadRequest( + kb_id='kb_purge_target', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='shared-doc')], + )) + + with ThreadPoolExecutor(max_workers=2) as pool: + purge_future = pool.submit(purge) + rebind_future = pool.submit(rebind) + purge_future.result(timeout=2) + rebound = rebind_future.result(timeout=2) + + manager_harness.finish_task(rebound[0]['task_id']) + + assert manager_harness.manager._has_kb_document('kb_purge_source', 'shared-doc') is False + assert manager_harness.manager._has_kb_document('kb_purge_target', 'shared-doc') is True + doc = manager_harness.manager._get_doc('shared-doc') + assert doc is not None + assert doc['path'] == file_path + + def test_manager_transfer_target_fields_override_source_defaults(manager_harness): manager_harness.manager.create_kb('kb_override_source', algo_id='__default__') manager_harness.manager.create_kb('kb_override_target', algo_id='__default__')