diff --git a/.gitignore b/.gitignore index 0e34cdd67..1b2ff5977 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ __pycache__ test/ dist/ tmp/ +tmp*/ build *.lock *.db diff --git a/docs/en/API Reference/tools.md b/docs/en/API Reference/tools.md index 3be3ecb24..e31bdf5e1 100644 --- a/docs/en/API Reference/tools.md +++ b/docs/en/API Reference/tools.md @@ -200,35 +200,29 @@ members: [find] exclude-members: -::: lazyllm.tools.rag.DocManager - members: document, list_kb_groups, add_files, reparse_files - exclude-members: +::: lazyllm.tools.rag.doc_service.DocServer + members: [upload, add, reparse, delete, transfer, patch_metadata, list_docs, get_doc, list_tasks, get_task, cancel_task, list_kbs, get_kb, list_chunks, list_algorithms, get_algorithm_info, create_kb, update_kb, batch_get_kbs, delete_kb, delete_kbs] + exclude-members: -::: lazyllm.tools.rag.utils.SqliteDocListManager - members: - - table_inited - - get_status_cond_and_params - - validate_paths - - update_need_reparsing - - list_files - - get_docs - - set_docs_new_meta - - fetch_docs_changed_meta - - list_all_kb_group - - add_kb_group - - list_kb_group_files - - delete_unreferenced_doc - - get_docs_need_reparse - - get_existing_paths_by_pattern - - update_file_message - - update_file_status - - add_files_to_kb_group - - delete_files_from_kb_group - - get_file_status - - update_kb_group - - release - - get_status_cond_and_params - exclude-members: +::: lazyllm.tools.rag.doc_service.base.AddFileItem + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.UploadRequest + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.AddRequest + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.TransferItem + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.TransferRequest + members: + exclude-members: ::: lazyllm.tools.rag.data_loaders.DirectoryReader members: load_data @@ -425,9 +419,6 @@ ::: lazyllm.tools.rag.rerank.ModuleReranker members: forward exclude-members: -::: lazyllm.tools.rag.utils.DocListManager - members: - exclude-members: ::: lazyllm.tools.rag.global_metadata.GlobalMetadataDesc members: exclude-members: diff --git a/docs/lazyllm-skill/assets/rag/document.md b/docs/lazyllm-skill/assets/rag/document.md index 26e1f1eaa..7cee6a46b 100644 --- a/docs/lazyllm-skill/assets/rag/document.md +++ b/docs/lazyllm-skill/assets/rag/document.md @@ -91,8 +91,8 @@ def processYml(file): ... print("Call the function processYml.") ... return [DocNode(text=data)] ... 
-doc1 = Document(dataset_path="your_files_path", create_ui=False) -doc2 = Document(dataset_path="your_files_path", create_ui=False) +doc1 = Document(dataset_path="your_files_path") +doc2 = Document(dataset_path="your_files_path") doc1.add_reader("**/*.yml", YmlReader) print(doc1._impl._local_file_reader) {'**/*.yml': } @@ -213,4 +213,3 @@ res = retriever(query=query) print(f"answer: {res}") ``` - diff --git a/docs/zh/API Reference/tools.md b/docs/zh/API Reference/tools.md index 1036a2fe4..90c13a0cf 100644 --- a/docs/zh/API Reference/tools.md +++ b/docs/zh/API Reference/tools.md @@ -189,35 +189,29 @@ members: [find] exclude-members: -::: lazyllm.tools.rag.DocManager - members: document, list_kb_groups, add_files, reparse_files - exclude-members: +::: lazyllm.tools.rag.doc_service.DocServer + members: [upload, add, reparse, delete, transfer, patch_metadata, list_docs, get_doc, list_tasks, get_task, cancel_task, list_kbs, get_kb, list_chunks, list_algorithms, get_algorithm_info, create_kb, update_kb, batch_get_kbs, delete_kb, delete_kbs] + exclude-members: -::: lazyllm.tools.rag.utils.SqliteDocListManager - members: - - table_inited - - get_status_cond_and_params - - validate_paths - - update_need_reparsing - - list_files - - get_docs - - set_docs_new_meta - - fetch_docs_changed_meta - - list_all_kb_group - - add_kb_group - - list_kb_group_files - - delete_unreferenced_doc - - get_docs_need_reparse - - get_existing_paths_by_pattern - - update_file_message - - update_file_status - - add_files_to_kb_group - - delete_files_from_kb_group - - get_file_status - - update_kb_group - - release - - get_status_cond_and_params - exclude-members: +::: lazyllm.tools.rag.doc_service.base.AddFileItem + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.UploadRequest + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.AddRequest + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.TransferItem + members: + exclude-members: + +::: lazyllm.tools.rag.doc_service.base.TransferRequest + members: + exclude-members: ::: lazyllm.tools.rag.data_loaders.DirectoryReader members: load_data @@ -414,9 +408,6 @@ ::: lazyllm.tools.rag.rerank.ModuleReranker members: forward exclude-members: -::: lazyllm.tools.rag.utils.DocListManager - members: - exclude-members: ::: lazyllm.tools.rag.global_metadata.GlobalMetadataDesc members: exclude-members: diff --git a/examples/rag/doc_service_mock_example.py b/examples/rag/doc_service_mock_example.py new file mode 100644 index 000000000..382652407 --- /dev/null +++ b/examples/rag/doc_service_mock_example.py @@ -0,0 +1,85 @@ +'''Connect a Document to a deployed DocServer. 
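+
+The script binds a ``Document`` to the running DocServer, adds a local file with
+``DocServer.add``, waits for the asynchronous parse task via ``get_task``, and then
+lists the documents in the target knowledge base with ``list_docs``.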
+ +Start DocServer first: + python examples/rag/doc_service_standalone.py --wait + +Run this example: + python examples/rag/doc_service_mock_example.py --doc-server-url http://127.0.0.1:8848 +''' + +from __future__ import annotations + +import argparse +import os +import tempfile +import time + +from lazyllm import Document +from lazyllm.tools.rag.doc_service import DocServer +from lazyllm.tools.rag.doc_service.base import AddFileItem, AddRequest + + +def _normalize_base_url(url: str) -> str: + url = url.rstrip('/') + if url.endswith('/_call') or url.endswith('/generate'): + return url.rsplit('/', 1)[0] + return url + + +def _wait_task(server: DocServer, task_id: str, timeout: float = 30.0): + deadline = time.time() + timeout + while time.time() < deadline: + task = server.get_task(task_id)['data'] + if task['status'] in {'SUCCESS', 'FAILED', 'CANCELED', 'DELETED'}: + return task + time.sleep(0.5) + raise TimeoutError(f'task {task_id} did not finish in time') + + +def main(): + parser = argparse.ArgumentParser(description='Connect a Document to an existing DocServer.') + parser.add_argument('--doc-server-url', type=str, required=True, help='Existing DocServer base URL.') + parser.add_argument('--algo-id', type=str, default='doc_service_demo_algo', help='Document algorithm ID.') + parser.add_argument('--kb-id', type=str, default='__default__', help='Knowledge base ID.') + args = parser.parse_args() + + doc_server = DocServer(url=_normalize_base_url(args.doc_server_url)) + with tempfile.TemporaryDirectory(prefix='lazyllm_doc_service_example_') as dataset_dir: + file_path = os.path.join(dataset_dir, 'demo.txt') + with open(file_path, 'w', encoding='utf-8') as file: + file.write('hello from a real doc_service example\n') + + # Step 1: create a Document and bind it to the deployed DocServer. + document = Document(dataset_path=dataset_dir, manager=doc_server, name=args.algo_id) + document.start() + + try: + base_url = _normalize_base_url(args.doc_server_url) + print(f'DocServer URL: {base_url}') + print(f'DocServer Docs: {base_url}/docs') + + # Step 2: add a local file through the DocServer client. + response = doc_server.add(AddRequest( + kb_id=args.kb_id, + algo_id=args.algo_id, + items=[AddFileItem(file_path=file_path)], + )) + item = response['data']['items'][0] + print(f'Doc ID: {item["doc_id"]}') + print(f'Task ID: {item["task_id"]}') + + # Step 3: wait for the asynchronous parse task to finish. + task = _wait_task(doc_server, item['task_id']) + print(f'Task Status: {task["status"]}') + + # Step 4: list documents from the target knowledge base. + docs = doc_server.list_docs( + kb_id=args.kb_id, algo_id=args.algo_id, include_deleted_or_canceled=False + )['data']['items'] + print(f'Doc Count In {args.kb_id}: {len(docs)}') + finally: + document.stop() + + +if __name__ == '__main__': + main() diff --git a/examples/rag/doc_service_standalone.py b/examples/rag/doc_service_standalone.py new file mode 100644 index 000000000..8d98d1f24 --- /dev/null +++ b/examples/rag/doc_service_standalone.py @@ -0,0 +1,199 @@ +'''Start a standalone DocServer example. 
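+
+The script wires a DocServer to a parsing service: it starts a local
+``DocumentProcessor`` or connects to an existing one via ``--parser-url``, and serves
+the document-management APIs on ``--port``.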
+ +Examples: + python examples/rag/doc_service_standalone.py --wait + python examples/rag/doc_service_standalone.py --parser-url http://127.0.0.1:9966 --wait + python examples/rag/doc_service_standalone.py --export-openapi +''' + +from __future__ import annotations + +import argparse +import os +import time +from typing import Any, Dict + +import requests + +from lazyllm import Document +from lazyllm.tools.rag.doc_service import DocServer +from lazyllm.tools.rag.doc_service.doc_server import DEFAULT_OPENAPI_OUTPUT_PATH +from lazyllm.tools.rag.parsing_service import DocumentProcessor + +REAL_ALGO_ID = 'real-standalone-algo' +FIXED_DB_ROOT = './tmp/db' +DEFAULT_OPENAPI_PATH = DEFAULT_OPENAPI_OUTPUT_PATH + + +def _make_db_config(db_name: str) -> Dict[str, Any]: + return {'db_type': 'sqlite', 'user': None, 'password': None, 'host': None, 'port': None, 'db_name': db_name} + + +def _prepare_runtime_paths() -> Dict[str, str]: + os.makedirs(FIXED_DB_ROOT, exist_ok=True) + paths = { + 'root_dir': FIXED_DB_ROOT, + 'storage_dir': os.path.join(FIXED_DB_ROOT, 'uploads'), + 'store_dir': os.path.join(FIXED_DB_ROOT, 'store'), + 'doc_db': os.path.join(FIXED_DB_ROOT, 'doc_service.sqlite'), + 'parser_db': os.path.join(FIXED_DB_ROOT, 'parser.sqlite'), + } + os.makedirs(paths['storage_dir'], exist_ok=True) + os.makedirs(paths['store_dir'], exist_ok=True) + return paths + + +def _wait_http_ok(url: str, timeout: float = 20.0): + deadline = time.time() + timeout + while time.time() < deadline: + try: + response = requests.get(url, timeout=3) + if response.status_code == 200: + return + except Exception: + pass + time.sleep(0.2) + raise RuntimeError(f'http service is not ready: {url}') + + +def _build_store_conf(root_dir: str) -> Dict[str, Any]: + segment_store_path = os.path.join(root_dir, 'segments.db') + milvus_store_path = os.path.join(root_dir, 'milvus_lite.db') + open(segment_store_path, 'a', encoding='utf-8').close() + open(milvus_store_path, 'a', encoding='utf-8').close() + return { + 'segment_store': {'type': 'map', 'kwargs': {'uri': segment_store_path}}, + 'vector_store': { + 'type': 'milvus', + 'kwargs': { + 'uri': milvus_store_path, + 'index_kwargs': {'index_type': 'FLAT', 'metric_type': 'COSINE'}, + }, + }, + } + + +def _wait_algo_ready(parser_url: str, algo_id: str, timeout: float = 20.0): + deadline = time.time() + timeout + while time.time() < deadline: + response = requests.get(f'{parser_url}/algo/list', timeout=5) + if response.status_code == 200: + items = response.json().get('data', []) + if any(item.get('algo_id') == algo_id for item in items): + return + time.sleep(0.2) + raise RuntimeError(f'algorithm is not ready: {algo_id}') + + +def _register_demo_algorithm(parser_url: str, algo_id: str, store_dir: str): + # Step 2: register a real Document algorithm on the parsing service. 
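+    # Note: the embedding below returns a fixed dummy vector and the store writes to
+    # local files under ``store_dir``, which keeps the demo self-contained; use a real
+    # embedding model and store configuration outside of this example.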
+ document = Document( + dataset_path=None, + name=algo_id, + embed={'vec_dense': lambda text: [1.0, 2.0, 3.0]}, + store_conf=_build_store_conf(store_dir), + display_name='Standalone Real Algo', + manager=DocumentProcessor(url=parser_url), + description='Algorithm registered by standalone doc service example', + ) + document.create_node_group( + name='line', + transform=lambda text: text.split('\n'), + parent='CoarseChunk', + display_name='Line Chunk', + ) + document.activate_group('CoarseChunk', embed_keys=['vec_dense']) + document.activate_group('line', embed_keys=['vec_dense']) + document.start() + _wait_algo_ready(parser_url, algo_id) + return document + + +def _start_local_parser(parser_port: int, parser_db: str): + # Step 1: start a local parsing service. + parser = DocumentProcessor(port=parser_port, db_config=_make_db_config(parser_db), num_workers=1) + parser.start() + parser_url = parser._impl._url.rsplit('/', 1)[0] + _wait_http_ok(f'{parser_url}/health') + return parser, parser_url + + +def main(): + parser = argparse.ArgumentParser(description='Standalone DocServer example.') + parser.add_argument('--port', type=int, default=8848, help='DocServer listen port.') + parser.add_argument('--parser-port', type=int, default=9966, help='DocumentProcessor listen port.') + parser.add_argument('--parser-url', type=str, default=None, help='Existing parsing service base URL.') + parser.add_argument('--algo-id', type=str, default=REAL_ALGO_ID, help='Algorithm ID for the local demo setup.') + parser.add_argument('--wait', action='store_true', help='Keep the example running for manual inspection.') + parser.add_argument( + '--export-openapi', + type=str, + nargs='?', + const=DEFAULT_OPENAPI_PATH, + default=None, + help=f'Export DocServer OpenAPI JSON and exit. Default path: {DEFAULT_OPENAPI_PATH}', + ) + args = parser.parse_args() + + if args.export_openapi: + print(f'OpenAPI exported: {DocServer.export_openapi(args.export_openapi)}', flush=True) + return + + paths = _prepare_runtime_paths() + parser_server = None + server = None + document = None + parser_url = args.parser_url + + try: + if parser_url: + parser_url = parser_url.rstrip('/') + _wait_http_ok(f'{parser_url}/health') + else: + parser_server, parser_url = _start_local_parser(args.parser_port, paths['parser_db']) + document = _register_demo_algorithm(parser_url, args.algo_id, paths['store_dir']) + + # Step 3: start DocServer and point it to the parsing service. + server = DocServer( + storage_dir=paths['storage_dir'], + db_config=_make_db_config(paths['doc_db']), + parser_url=parser_url, + port=args.port, + ) + server.start() + base_url = server.url.rsplit('/', 1)[0] + _wait_http_ok(f'{base_url}/v1/health') + + print(f'DocServer URL: {base_url}', flush=True) + print(f'DocServer Docs: {base_url}/docs', flush=True) + print(f'Parser URL: {parser_url}', flush=True) + print(f'Storage Dir: {paths["storage_dir"]}', flush=True) + print(f'Doc DB: {paths["doc_db"]}', flush=True) + if document: + print(f'Algorithm ID: {args.algo_id}', flush=True) + + if args.wait: + # Step 4: keep the services alive for manual API testing. + print('Services are running. 
Press Ctrl+C to stop.', flush=True) + while True: + time.sleep(1) + finally: + if server: + try: + server.stop() + except Exception: + pass + if document: + try: + document.stop() + except Exception: + pass + if parser_server: + try: + parser_server.stop() + except Exception: + pass + + +if __name__ == '__main__': + main() diff --git a/examples/rag_with_parsing_service/README.CN.md b/examples/rag_with_parsing_service/README.CN.md index 787e40d74..20d92fe44 100644 --- a/examples/rag_with_parsing_service/README.CN.md +++ b/examples/rag_with_parsing_service/README.CN.md @@ -23,6 +23,8 @@ RAG 解析服务示例 参考 `document.py` 的配置: - `manager=DocumentProcessor(url="http://0.0.0.0:9966")` 指向解析服务。 +- 该模式下 `Document` 不会创建 `DocServer` 或 UI,`dataset_path` 也不会再用于本地扫盘。 +- 必须显式提供 `store_conf`,且不能使用纯 map store。 - `server=9977` 将 Document 作为服务暴露。 然后在 `retriever_using_url.py` 中使用: diff --git a/examples/rag_with_parsing_service/README.md b/examples/rag_with_parsing_service/README.md index bfd7c040e..441844351 100644 --- a/examples/rag_with_parsing_service/README.md +++ b/examples/rag_with_parsing_service/README.md @@ -26,6 +26,8 @@ others can access it remotely by URL, and Retrievers can use that URL directly. See `document.py` for the setup: - `manager=DocumentProcessor(url="http://0.0.0.0:9966")` points to the parsing service. +- In this mode `Document` does not create a `DocServer` or UI, and `dataset_path` is not used for local path monitoring. +- `store_conf` is required and must not be a pure map store. - `server=9977` exposes the Document as a service. Then use `retriever_using_url.py` to create: diff --git a/lazyllm/common/utils.py b/lazyllm/common/utils.py index bfed8e206..68745f06e 100644 --- a/lazyllm/common/utils.py +++ b/lazyllm/common/utils.py @@ -3,6 +3,7 @@ from typing import Union, Dict, Callable, Any, Optional import re import os +import sys from contextlib import contextmanager import cloudpickle import ast @@ -161,11 +162,37 @@ def str2bool(v: str) -> bool: raise argparse.ArgumentTypeError('Boolean value expected.') def dump_obj(f): + def _collect_test_modules(obj): + modules = [] + seen = set() + candidates = [obj] + if hasattr(obj, '__dict__'): + candidates.extend(obj.__dict__.values()) + for candidate in candidates: + module_name = getattr(candidate, '__module__', None) + if not module_name: + continue + if not (module_name.startswith('test_') or module_name.startswith('tmp.tests.')): + continue + module = sys.modules.get(module_name) + if module is None or module_name in seen: + continue + seen.add(module_name) + modules.append(module) + return modules + @contextmanager def env_helper(): os.environ['LAZYLLM_ON_CLOUDPICKLE'] = 'ON' - yield - os.environ['LAZYLLM_ON_CLOUDPICKLE'] = 'OFF' + modules = _collect_test_modules(f) + for module in modules: + cloudpickle.register_pickle_by_value(module) + try: + yield + finally: + for module in modules: + cloudpickle.unregister_pickle_by_value(module) + os.environ['LAZYLLM_ON_CLOUDPICKLE'] = 'OFF' with env_helper(): return None if f is None else base64.b64encode(cloudpickle.dumps(f)).decode('utf-8') diff --git a/lazyllm/components/deploy/relay/server.py b/lazyllm/components/deploy/relay/server.py index dad7e3021..51e713562 100644 --- a/lazyllm/components/deploy/relay/server.py +++ b/lazyllm/components/deploy/relay/server.py @@ -1,26 +1,44 @@ -from lazyllm.common.utils import str2obj -import uvicorn import argparse -import os -import sys +import asyncio +import codecs import inspect -import traceback -from types import GeneratorType -import lazyllm -from 
lazyllm import kwargs, package, load_obj -from lazyllm import FastapiApp, globals -from lazyllm.common import _trim_traceback, _register_trim_module -import time +import os import pickle -import codecs -import asyncio import functools +import sys +import time +import traceback from functools import partial +from types import GeneratorType from typing import Callable -from fastapi import FastAPI, Request -from fastapi.responses import Response, StreamingResponse -import requests + +def _inject_pythonpath(argv): + pythonpath = None + for index, arg in enumerate(argv): + if arg == '--pythonpath' and index + 1 < len(argv): + pythonpath = argv[index + 1] + break + if arg.startswith('--pythonpath='): + pythonpath = arg.split('=', 1)[1] + break + if pythonpath: + pythonpath = os.path.abspath(pythonpath) + if pythonpath in sys.path: + sys.path.remove(pythonpath) + sys.path.insert(0, pythonpath) + + +_inject_pythonpath(sys.argv[1:]) + +from lazyllm.common.utils import str2obj # noqa: E402 +import uvicorn # noqa: E402 +import lazyllm # noqa: E402 +from lazyllm import FastapiApp, globals, kwargs, load_obj, package # noqa: E402 +from lazyllm.common import _register_trim_module, _trim_traceback # noqa: E402 + +from lazyllm.thirdparty import fastapi # noqa: E402 +import requests # noqa: E402 # TODO(sunxiaoye): delete in the future lazyllm_module_dir = os.path.abspath(__file__) @@ -43,7 +61,10 @@ args = parser.parse_args() if args.pythonpath: - sys.path.append(args.pythonpath) + pythonpath = os.path.abspath(args.pythonpath) + if pythonpath in sys.path: + sys.path.remove(pythonpath) + sys.path.insert(0, pythonpath) func = load_obj(args.function) if args.before_function: @@ -58,7 +79,7 @@ _err_msg += (f' defined at `{load_obj(args.defined_pos)}`' if args.defined_pos else '') + ':\n' -app = FastAPI() +app = fastapi.FastAPI() FastapiApp.update() async def async_wrapper(func, *args, **kwargs): @@ -73,30 +94,30 @@ def impl(func, sid, global_data, *args, **kw): def security_check(f: Callable): @functools.wraps(f) - async def wrapper(request: Request): + async def wrapper(request: fastapi.Request): if args.security_key and args.security_key != request.headers.get('Security-Key'): - return Response(content='Authentication failed', status_code=401) + return fastapi.responses.Response(content='Authentication failed', status_code=401) return (await f(request)) if inspect.iscoroutinefunction(f) else f(request) return wrapper @app.post('/_call') @security_check -async def lazyllm_call(request: Request): +async def lazyllm_call(request: fastapi.Request): try: fname, args, kwargs = await request.json() args, kwargs = load_obj(args), load_obj(kwargs) r = await async_wrapper(getattr(func, fname), *args, **kwargs) - return Response(content=codecs.encode(pickle.dumps(r), 'base64')) + return fastapi.responses.Response(content=codecs.encode(pickle.dumps(r), 'base64')) except requests.RequestException as e: - return Response(content=f'{str(e)}', status_code=500) + return fastapi.responses.Response(content=f'{str(e)}', status_code=500) except Exception: exc_type, exc_value, exc_tb = sys.exc_info() formatted = ''.join(traceback.format_exception(exc_type, exc_value, _trim_traceback(exc_tb))) - return Response(content=f'{_err_msg}\n{formatted}', status_code=500) + return fastapi.responses.Response(content=f'{_err_msg}\n{formatted}', status_code=500) @app.post('/generate') @security_check -async def generate(request: Request): # noqa C901 +async def generate(request: fastapi.Request): # noqa C901 try: input, kw = (await 
request.json()), {} try: @@ -135,7 +156,7 @@ def impl(o): def generate_stream(): for o in output: yield impl(o) - return StreamingResponse(generate_stream(), media_type='text_plain') + return fastapi.responses.StreamingResponse(generate_stream(), media_type='text_plain') elif args.after_function: assert (callable(after_func)), 'after_func must be callable' r = inspect.getfullargspec(after_func) @@ -147,13 +168,13 @@ def generate_stream(): after_func(output, **{r.kwonlyargs[0]: origin}) elif len(new_args) == 2: output = after_func(output, origin) - return Response(content=impl(output)) + return fastapi.responses.Response(content=impl(output)) except requests.RequestException as e: - return Response(content=f'{str(e)}', status_code=500) + return fastapi.responses.Response(content=f'{str(e)}', status_code=500) except Exception: exc_type, exc_value, exc_tb = sys.exc_info() formatted = ''.join(traceback.format_exception(exc_type, exc_value, _trim_traceback(exc_tb))) - return Response(content=f'{_err_msg}\n{formatted}', status_code=500) + return fastapi.responses.Response(content=f'{_err_msg}\n{formatted}', status_code=500) finally: globals.clear() diff --git a/lazyllm/docs/tools/tool_rag.py b/lazyllm/docs/tools/tool_rag.py index bd022c1b7..c7e62c1ae 100644 --- a/lazyllm/docs/tools/tool_rag.py +++ b/lazyllm/docs/tools/tool_rag.py @@ -14,17 +14,18 @@ Args: dataset_path (Optional[str]): Path to the dataset directory. If not found, the system will attempt to locate it in ``lazyllm.config["data_path"]``. embed (Optional[Union[Callable, Dict[str, Callable]]]): Embedding function or mapping of embedding functions. When a dictionary is provided, keys are embedding names and values are embedding models. - manager (Union[bool, str], optional): Whether to enable the document manager. If ``True``, launches a manager service. If ``'ui'``, also enables the document management web UI. Defaults to ``False``. + create_ui (bool, optional): Whether to create the document-management UI. It requires an available ``DocServer`` and can be combined with ``manager=True`` or ``manager=DocServer(...)``. + manager (Union[bool, str, DocServer, Document._Manager, DocumentProcessor], optional): Document manager mode. ``True`` launches a local ``DocServer`` together with a local parsing service. ``DocServer(...)`` connects an existing document-management service. ``DocumentProcessor(...)`` connects a parsing service only and requires a non-map ``store_conf``. ``'ui'`` is accepted as a deprecated compatibility alias for ``manager=True, create_ui=True``. server (Union[bool, int], optional): Whether to run a server interface for knowledge bases. ``True`` enables a default server, an integer specifies a custom port, and ``False`` disables it. Defaults to ``False``. name (Optional[str]): Name identifier for this document collection. Defaults to the system default name. launcher (Optional[Launcher]): Launcher instance for managing server processes. Defaults to a remote asynchronous launcher. - store_conf (Optional[Dict]): Storage configuration. Defaults to in-memory MapStore. - doc_fields (Optional[Dict[str, DocField]]): Metadata field configuration for storing and retrieving document attributes. - cloud (bool): Whether the dataset is stored in the cloud. Defaults to ``False``. doc_files (Optional[List[str]]): Temporary document files. When used, ``dataset_path`` must be ``None``. Only MapStore is supported in this mode. - processor (Optional[DocumentProcessor]): Document processing service. 
+ doc_fields (Optional[Dict[str, DocField]]): Metadata field configuration for storing and retrieving document attributes. + store_conf (Optional[Dict]): Storage configuration. Defaults to in-memory MapStore. display_name (Optional[str]): Human-readable display name for this document module. Defaults to the collection name. description (Optional[str]): Description of the document collection. Defaults to ``"algorithm description"``. + schema_extractor (Optional[Union[LLMBase, SchemaExtractor]]): Optional schema extractor used for metadata schema analysis and registration. + enable_path_monitoring (Optional[bool]): Whether to watch the local dataset path for file additions and removals. Defaults to enabled only for local documents without ``DocServer``/``DocumentProcessor`` manager mode. ''') add_chinese_doc('Document', '''\ @@ -35,17 +36,226 @@ Args: dataset_path (Optional[str]): 数据集目录路径。如果路径不存在,系统会尝试在 ``lazyllm.config["data_path"]`` 中查找。 embed (Optional[Union[Callable, Dict[str, Callable]]]): 文档向量化函数或函数字典。若为字典,键为 embedding 名称,值为对应的模型。 - manager (Union[bool, str], optional): 是否启用文档管理服务。``True`` 表示启动管理服务;``'ui'`` 表示同时启动 Web 管理界面;默认 ``False``。 + create_ui (bool, optional): 是否创建文档管理 UI。该能力要求当前存在可用的 ``DocServer``,可与 ``manager=True`` 或 ``manager=DocServer(...)`` 组合使用。 + manager (Union[bool, str, DocServer, Document._Manager, DocumentProcessor], optional): 文档管理模式。``True`` 表示启动本地 ``DocServer`` 及其 parsing service;``DocServer(...)`` 表示连接已有文档管理服务;``DocumentProcessor(...)`` 表示仅连接解析服务,此时必须提供非 map 的 ``store_conf``;``'ui'`` 仅作为 ``manager=True, create_ui=True`` 的兼容写法保留。 server (Union[bool, int], optional): 是否为知识库运行服务接口。``True`` 表示启动默认服务;整型数值表示自定义端口;``False`` 表示关闭。默认为 ``False``。 name (Optional[str]): 文档集合的名称标识符。默认为系统默认名称。 launcher (Optional[Launcher]): 启动器实例,用于管理服务进程。默认使用远程异步启动器。 - store_conf (Optional[Dict]): 存储配置。默认使用内存中的 MapStore。 - doc_fields (Optional[Dict[str, DocField]]): 元数据字段配置,用于存储和检索文档属性。 - cloud (bool): 是否为云端数据集。默认为 ``False``。 doc_files (Optional[List[str]]): 临时文档文件列表。当使用此参数时,``dataset_path`` 必须为 ``None``,且仅支持 MapStore。 - processor (Optional[DocumentProcessor]): 文档处理服务。 + doc_fields (Optional[Dict[str, DocField]]): 元数据字段配置,用于存储和检索文档属性。 + store_conf (Optional[Dict]): 存储配置。默认使用内存中的 MapStore。 display_name (Optional[str]): 文档模块的可读显示名称。默认为集合名称。 description (Optional[str]): 文档集合的描述。默认为 ``"algorithm description"``。 + schema_extractor (Optional[Union[LLMBase, SchemaExtractor]]): 可选 schema extractor,用于元数据 schema 分析与注册。 + enable_path_monitoring (Optional[bool]): 是否监控本地数据目录的文件新增和删除。仅在未接入 ``DocServer`` / ``DocumentProcessor`` 的本地模式下默认开启。 +''') + +add_english_doc('DocServer', '''\ +Primary entry point of the document service. + +``DocServer`` manages document upload/add/reparse/delete flows, task tracking, knowledge-base management, +chunk inspection, and cross-kb transfer. It is the recommended replacement for the legacy ``DocManager`` / +``DocListManager`` APIs. + +Args: + port (Optional[int]): Local service port when starting an in-process server. + url (Optional[str]): Existing doc_service URL. When provided, the instance works as a remote client. + parser_url (Optional[str]): Parsing service URL used by the local doc_service instance. + db_config (Optional[Dict[str, Any]]): Metadata database configuration for doc_service. + parser_db_config (Optional[Dict[str, Any]]): Parsing task database configuration for the parsing service. + parser_poll_interval (float): Poll interval used by local parser coordination. + storage_dir (Optional[str]): Local storage directory for uploaded files. 
+ callback_url (Optional[str]): Callback URL used to receive parsing task updates. + launcher: Launcher used to start local services. +''') + +add_chinese_doc('DocServer', '''\ +文档服务的主入口。 + +``DocServer`` 负责文档上传/添加/重解析/删除、任务跟踪、知识库管理、chunk 查看,以及跨知识库文档转移。 +它是 legacy ``DocManager`` / ``DocListManager`` API 的推荐替代方案。 + +Args: + port (Optional[int]): 本地启动服务时使用的端口。 + url (Optional[str]): 已存在的 doc_service 地址;提供后当前实例作为远程客户端使用。 + parser_url (Optional[str]): 本地 doc_service 使用的 parsing service 地址。 + db_config (Optional[Dict[str, Any]]): doc_service 元数据数据库配置。 + parser_db_config (Optional[Dict[str, Any]]): parsing service 任务数据库配置。 + parser_poll_interval (float): 本地解析协调使用的轮询间隔。 + storage_dir (Optional[str]): 上传文件保存目录。 + callback_url (Optional[str]): 接收解析任务回调的地址。 + launcher: 本地服务启动器。 +''') + +add_english_doc('DocServer.add', '''\ +Add existing local files through the ``/v1/docs/add`` endpoint. + +Use this method when the file paths are already accessible on the DocServer host. The request body is an +``AddRequest`` containing ``kb_id``, ``algo_id``, and ``items``. Each item can provide ``file_path``, +optional ``doc_id``, and optional ``metadata``. + +Returns: + Standard API response. ``data["items"]`` contains the accepted ``doc_id`` and asynchronous ``task_id``. +''') + +add_chinese_doc('DocServer.add', '''\ +通过 ``/v1/docs/add`` 接口添加服务端可直接访问的本地文件。 + +当文件路径已经对 DocServer 所在机器可见时,使用该方法。请求体为 ``AddRequest``,包含 ``kb_id``、``algo_id`` +和 ``items``。每个 item 可提供 ``file_path``,以及可选的 ``doc_id``、``metadata``。 + +Returns: + 标准 API 响应。``data["items"]`` 中包含接受后的 ``doc_id`` 和异步 ``task_id``。 +''') + +add_english_doc('DocServer.upload', '''\ +Upload files into DocServer-managed storage through the ``/v1/docs/upload`` flow. + +Use this method when you want DocServer to manage uploaded copies of the source files. The request body is an +``UploadRequest`` with ``kb_id``, ``algo_id``, and ``items``. Each item uses ``file_path`` as the local source +path and can optionally include ``doc_id`` or ``metadata``. + +Returns: + Standard API response. ``data["items"]`` contains the accepted ``doc_id`` and asynchronous ``task_id``. +''') + +add_chinese_doc('DocServer.upload', '''\ +通过 ``/v1/docs/upload`` 流程将文件上传到 DocServer 管理的存储目录。 + +当你希望由 DocServer 保存上传副本时,使用该方法。请求体为 ``UploadRequest``,包含 ``kb_id``、``algo_id`` +和 ``items``。每个 item 使用 ``file_path`` 作为本地源路径,也可以附带可选的 ``doc_id``、``metadata``。 + +Returns: + 标准 API 响应。``data["items"]`` 中包含接受后的 ``doc_id`` 和异步 ``task_id``。 +''') + +add_english_doc('DocServer.reparse', '''\ +Reparse existing documents through the ``/v1/docs/reparse`` endpoint. + +The request body is a ``ReparseRequest`` with ``kb_id``, ``algo_id``, and ``doc_ids``. Use it after metadata +or parsing configuration changes when you want to enqueue new parse tasks for existing documents. +''') + +add_chinese_doc('DocServer.reparse', '''\ +通过 ``/v1/docs/reparse`` 接口重新解析已有文档。 + +请求体为 ``ReparseRequest``,包含 ``kb_id``、``algo_id`` 和 ``doc_ids``。当元数据或解析配置变更后, +需要为已有文档重新入队解析任务时,可使用该方法。 +''') + +add_english_doc('DocServer.delete', '''\ +Delete documents from a knowledge base through the ``/v1/docs/delete`` endpoint. + +The request body is a ``DeleteRequest`` with ``kb_id``, ``algo_id``, and ``doc_ids``. Deletion is asynchronous, +so the returned ``task_id`` should be tracked through the task APIs when you need final status. 
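+
+A minimal sketch of tracking a delete task (``server`` is a ``DocServer`` client, the IDs are
+placeholders, and the delete response layout is assumed to mirror ``add``/``upload``):
+
+```python
+from lazyllm.tools.rag.doc_service.base import DeleteRequest
+
+resp = server.delete(DeleteRequest(kb_id='kb1', algo_id='demo_algo', doc_ids=['doc-1']))
+task_id = resp['data']['items'][0]['task_id']  # assumed to mirror the add()/upload() response
+print(server.get_task(task_id)['data']['status'])  # e.g. DELETING / DELETED / FAILED
+```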
+''') + +add_chinese_doc('DocServer.delete', '''\ +通过 ``/v1/docs/delete`` 接口从知识库中删除文档。 + +请求体为 ``DeleteRequest``,包含 ``kb_id``、``algo_id`` 和 ``doc_ids``。删除是异步操作,因此如果需要最终状态, +应继续通过任务接口跟踪返回的 ``task_id``。 +''') + +add_english_doc('DocServer.patch_metadata', '''\ +Patch document metadata through the ``/v1/docs/metadata/patch`` endpoint. + +The request body is a ``MetadataPatchRequest`` with ``kb_id``, ``algo_id``, and ``items``. Each item targets one +document and carries a partial metadata patch in ``patch``. +''') + +add_chinese_doc('DocServer.patch_metadata', '''\ +通过 ``/v1/docs/metadata/patch`` 接口更新文档元数据。 + +请求体为 ``MetadataPatchRequest``,包含 ``kb_id``、``algo_id`` 和 ``items``。每个 item 指向一个文档, +并在 ``patch`` 中携带需要合并的局部元数据。 +''') + +add_english_doc('DocServer.get_task', '''\ +Get one task record through the ``/v1/tasks/{task_id}`` endpoint. + +Args: + task_id (str): Task ID returned by add, upload, reparse, delete, transfer, or metadata patch operations. + +Returns: + Standard API response with the current task status and task payload. +''') + +add_chinese_doc('DocServer.get_task', '''\ +通过 ``/v1/tasks/{task_id}`` 接口获取单个任务记录。 + +Args: + task_id (str): add、upload、reparse、delete、transfer 或 metadata patch 等操作返回的任务 ID。 + +Returns: + 包含当前任务状态和任务负载的标准 API 响应。 +''') + +add_english_doc('DocServer.cancel_task', '''\ +Cancel a waiting task through the ``/v1/tasks/cancel`` endpoint. + +Args: + task_id (str): Task ID to cancel. + +Returns: + Standard API response indicating whether the task was canceled successfully. +''') + +add_chinese_doc('DocServer.cancel_task', '''\ +通过 ``/v1/tasks/cancel`` 接口取消一个处于等待中的任务。 + +Args: + task_id (str): 要取消的任务 ID。 + +Returns: + 表示任务是否取消成功的标准 API 响应。 +''') + +add_english_doc('DocServer.list_chunks', '''\ +List parsed chunks for a document through the ``/v1/chunks`` endpoint. + +Args: + kb_id (str): Knowledge-base ID. + doc_id (str): Source document ID. + group (str): Node group name to inspect. + algo_id (str): Algorithm ID. + page (int): 1-based page number. + page_size (int): Number of chunks per page. + offset (Optional[int]): Explicit offset. When omitted, the service derives it from ``page`` and ``page_size``. + +Returns: + Paginated chunk data including ``items`` and ``total``. +''') + +add_chinese_doc('DocServer.list_chunks', '''\ +通过 ``/v1/chunks`` 接口分页查看文档的解析 chunk。 + +Args: + kb_id (str): 知识库 ID。 + doc_id (str): 文档 ID。 + group (str): 要查看的节点组名。 + algo_id (str): 算法 ID。 + page (int): 从 1 开始的页码。 + page_size (int): 每页 chunk 数量。 + offset (Optional[int]): 显式偏移量;未传时服务端会根据 ``page`` 和 ``page_size`` 推导。 + +Returns: + 包含 ``items`` 与 ``total`` 的分页结果。 +''') + +add_english_doc('DocServer.transfer', '''\ +Transfer parsed documents between knowledge bases under the same algorithm. + +The request body is a ``TransferRequest``. Each transfer item must provide a unique ``target_doc_id`` in the target +knowledge base. Transfer across different algorithms is not supported. Optional ``target_filename`` and +``target_file_path`` can override the destination file name/path recorded for the transferred document. +''') + +add_chinese_doc('DocServer.transfer', '''\ +在同一算法下的不同知识库之间转移已解析文档。 + +请求体为 ``TransferRequest``。每个转移项都必须在目标知识库中提供唯一的 ``target_doc_id``。 +当前不支持跨算法 transfer。可选字段 ``target_filename`` 与 ``target_file_path`` 用于覆盖目标文档记录的文件名或文件路径。 ''') add_example('Document', '''\ @@ -507,8 +717,8 @@ ... data = f.read() ... return [DocNode(text=data)] ... 
->>> doc1 = Document(dataset_path="your_files_path", create_ui=False) ->>> doc2 = Document(dataset_path="your_files_path", create_ui=False) +>>> doc1 = Document(dataset_path="your_files_path") +>>> doc2 = Document(dataset_path="your_files_path") >>> files = ["your_yml_files"] >>> docs1 = doc1._impl._reader.load_data(input_files=files) >>> docs2 = doc2._impl._reader.load_data(input_files=files) @@ -552,8 +762,8 @@ ... print("Call the function processYml.") ... return [DocNode(text=data)] ... ->>> doc1 = Document(dataset_path="your_files_path", create_ui=False) ->>> doc2 = Document(dataset_path="your_files_path", create_ui=False) +>>> doc1 = Document(dataset_path="your_files_path") +>>> doc2 = Document(dataset_path="your_files_path") >>> doc1.add_reader("**/*.yml", YmlReader) >>> print(doc1._impl._local_file_reader) {'**/*.yml': } @@ -992,7 +1202,7 @@ ... return [DocNode(text=data)] ... >>> files = ["your_yml_files"] ->>> doc = Document(dataset_path="your_files_path", create_ui=False) +>>> doc = Document(dataset_path="your_files_path") >>> reader = doc._impl._reader.load_data(input_files=files) # Call the class YmlReader. ''') @@ -1419,7 +1629,7 @@ 将算法/知识库绑定到指定 schema 集合;若提供 schema_set 会先注册;可选 force_refresh 覆盖已有绑定并清理旧数据。 Args: - algo_id (str, optional): 算法/Document 名称,默认 DocListManager.DEFAULT_GROUP_NAME。 + algo_id (str, optional): 算法/Document 名称,默认 `__default__`。 kb_id (str, optional): 知识库 ID,默认 DEFAULT_KB_ID。 schema_set_id (str, optional): 已有 schema 集合 ID。 schema_set (Type[BaseModel], optional): 新 schema,传入则会注册后绑定。 @@ -1433,7 +1643,7 @@ Bind an algo/kb pair to a schema set; optionally register a provided schema_set first; with force_refresh you can override an existing binding and purge old records. Args: - algo_id (str, optional): Algorithm/Document name, defaults to DocListManager.DEFAULT_GROUP_NAME. + algo_id (str, optional): Algorithm/Document name, defaults to `__default__`. kb_id (str, optional): Knowledge base id, defaults to DEFAULT_KB_ID. schema_set_id (str, optional): Existing schema set id to bind. schema_set (Type[BaseModel], optional): Schema to register and bind if no id is provided. @@ -1470,7 +1680,7 @@ Args: data (Union[str, List[DocNode]]): 文本或 DocNode 列表(需同一文档)。 - algo_id (str, optional): 算法/Document 名称,默认 DocListManager.DEFAULT_GROUP_NAME。 + algo_id (str, optional): 算法/Document 名称,默认 `__default__`。 schema_set_id (str, optional): 指定使用的 schema 集合 ID。 schema_set (Type[BaseModel], optional): 动态注册并使用的 schema。 @@ -1483,7 +1693,7 @@ Args: data (Union[str, List[DocNode]]): Text or list of DocNodes from a single document. - algo_id (str, optional): Algorithm/Document name, defaults to DocListManager.DEFAULT_GROUP_NAME. + algo_id (str, optional): Algorithm/Document name, defaults to `__default__`. schema_set_id (str, optional): Schema set id to use. schema_set (Type[BaseModel], optional): Schema to register and use if no id is provided. @@ -5591,10 +5801,10 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: # rag/doc_manager.py add_chinese_doc('rag.DocManager', """ -DocManager类管理文档列表及相关操作,并通过API提供文档上传、删除、分组等功能。 +已废弃。请改用 `lazyllm.tools.rag.doc_service.DocServer`。 Args: - dlm (DocListManager): 文档列表管理器,用于处理具体的文档操作。 + dlm (DocListManager): 旧文档列表管理器。 """) @@ -5756,10 +5966,10 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: """) add_english_doc('rag.DocManager', """ -The `DocManager` class manages document lists and related operations, providing APIs for uploading, deleting, and grouping documents. +Deprecated. 
Use `lazyllm.tools.rag.doc_service.DocServer` instead. Args: - dlm (DocListManager): Document list manager responsible for handling document-related operations. + dlm (DocListManager): Legacy document list manager. """) add_english_doc('rag.DocManager.document', """ @@ -6021,7 +6231,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: # rag/utils.py add_chinese_doc('rag.utils.DocListManager', """\ -抽象基类,用于管理文档列表和监控文档目录变化。 +已废弃。请改用 `Document(dataset_path=..., enable_path_monitoring=...)` 或 `DocServer`。 Args: path:要监控的文档目录路径。 @@ -6197,7 +6407,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: 说明: - 方法首先通过辅助函数 `_add_doc_records` 创建文档记录。 - - 文件添加后,会自动关联到默认的知识库组 (`DocListManager.DEFAULT_GROUP_NAME`)。 + - 文件添加后,会自动关联到默认的知识库组(`__default__`)。 - 批量处理确保在添加大量文件时具有良好的可扩展性。 ''') @@ -6315,7 +6525,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: ''') add_english_doc('rag.utils.DocListManager', """\ -Abstract base class for managing document lists and monitoring changes in a document directory. +Deprecated. Use `Document(dataset_path=..., enable_path_monitoring=...)` or `DocServer`. Args: path: Path of the document directory to monitor. @@ -6469,7 +6679,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: If `details=True`, returns a list of detailed rows with additional metadata. Notes: - The method first creates document records using the `_add_doc_records` helper function. - - After the files are added, they are automatically linked to the default KB group (`DocListManager.DEFAULT_GROUP_NAME`). + - After the files are added, they are automatically linked to the default KB group (`__default__`). - Batch processing ensures scalability when adding a large number of files. ''') @@ -6488,7 +6698,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: Notes: - The method first creates document records using the helper function _add_doc_records. -- After the files are added, they are automatically linked to the default knowledge base group (DocListManager.DEFAULT_GROUP_NAME). +- After the files are added, they are automatically linked to the default knowledge base group (`__default__`). - Batch processing ensures good scalability when adding a large number of files. @@ -6608,15 +6818,14 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: add_example('rag.utils.DocListManager', ''' >>> import lazyllm ->>> from lazyllm.rag.utils import DocListManager ->>> manager = DocListManager(path='your_file_path/', name="test_manager", enable_path_monitoring=False) +>>> # Deprecated. Use Document(dataset_path='your_file_path', enable_path_monitoring=True) >>> added_docs = manager.add_files([test_file_list]) >>> manager.enable_path_monitoring(True) >>> deleted = manager.delete_files([delete_file_list]) ''') add_chinese_doc('rag.utils.SqliteDocListManager', '''\ -基于 SQLite 的文档管理器,用于本地文件的持久化存储、状态管理与元信息追踪。 +已废弃。请改用 `Document(dataset_path=...)` 或新的 `doc_service` 数据库。 该类继承自 DocListManager,利用 SQLite 数据库存储文档记录。适用于管理具有唯一标识符的本地文档资源,并提供便捷的插入、查询、更新与状态过滤接口,支持可选的路径监控功能。 @@ -6627,7 +6836,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: ''') add_english_doc('rag.utils.SqliteDocListManager', '''\ -SQLite-based document manager for persistent local file storage, status tracking, and metadata management. +Deprecated. Use `Document(dataset_path=...)` or the new `doc_service` database instead. 
This class inherits from DocListManager and uses a SQLite backend to store document records. It is suitable for managing locally identified documents with support for inserting, querying, updating, and filtering based on status. Optional file path monitoring is also supported. @@ -6638,8 +6847,7 @@ def _lazy_load_data(self, file_paths: list, **kwargs) -> Iterable[DocNode]: ''') add_example('rag.utils.SqliteDocListManager', '''\ ->>> from lazyllm.tools.rag.utils import SqliteDocListManager ->>> manager = SqliteDocListManager(path="./data", name="docs.sqlite") +>>> # Deprecated. Use Document(dataset_path=...) or DocServer instead. >>> manager.insert({"uid": "doc_001", "name": "example.txt", "status": "ready"}) >>> print(manager.get("doc_001")) >>> files = manager.list_files(limit=5, details=True) diff --git a/lazyllm/tools/__init__.py b/lazyllm/tools/__init__.py index 9dd625ac9..62da11d3c 100644 --- a/lazyllm/tools/__init__.py +++ b/lazyllm/tools/__init__.py @@ -3,7 +3,7 @@ if TYPE_CHECKING: # flake8: noqa: E401 from .rag import (Document, GraphDocument, UrlGraphDocument, Reranker, Retriever, TempDocRetriever, - GraphRetriever, SentenceSplitter, LLMParser) + GraphRetriever, SentenceSplitter, LLMParser, DocServer) from .webpages import WebModule from .fs import (LazyLLMFSBase, CloudFSBufferedFile, CloudFS, CloudFsWatchdog, FeishuFS, ConfluenceFS, NotionFS, GoogleDriveFS, OneDriveFS, YuqueFS, OnesFS, S3FS, @@ -55,6 +55,7 @@ def __getattr__(name: str): _SUBMOD_MAP = { 'rag': [ 'Document', + 'DocServer', 'GraphDocument', 'UrlGraphDocument', 'Reranker', diff --git a/lazyllm/tools/rag/__init__.py b/lazyllm/tools/rag/__init__.py index ec4489dab..03f3039df 100644 --- a/lazyllm/tools/rag/__init__.py +++ b/lazyllm/tools/rag/__init__.py @@ -1,8 +1,11 @@ from lazyllm.thirdparty import check_dependency_by_group + check_dependency_by_group('rag') # flake8: noqa: E402 from .document import Document +from .doc_manager import DocManager +from .doc_service import DocServer from .graph_document import GraphDocument, UrlGraphDocument from .retriever import Retriever, TempDocRetriever, ContextRetriever, WeightedRetriever, PriorityRetriever from .graph_retriever import GraphRetriever @@ -11,12 +14,11 @@ CharacterSplitter, RecursiveSplitter, MarkdownSplitter, CodeSplitter, JSONSplitter, YAMLSplitter, HTMLSplitter, XMLSplitter, GeneralCodeSplitter, JSONLSplitter) from .similarity import register_similarity -from .doc_node import DocNode +from .doc_node import DocNode, RichDocNode from .readers import (PDFReader, DocxReader, HWPReader, PPTXReader, ImageReader, IPYNBReader, EpubReader, MarkdownReader, MboxReader, PandasCSVReader, PandasExcelReader, VideoAudioReader, MineruPDFReader) from .dataReader import SimpleDirectoryReader, FileReader -from .doc_manager import DocManager, DocListManager from .global_metadata import GlobalMetadataDesc as DocField from .data_type import DataType from .index_base import IndexBase @@ -29,6 +31,8 @@ __all__ = [ 'add_post_action_for_default_reader', 'Document', + 'DocManager', + 'DocServer', 'GraphDocument', 'UrlGraphDocument', 'Reranker', @@ -46,6 +50,7 @@ 'register_similarity', 'register_reranker', 'DocNode', + 'RichDocNode', 'PDFReader', 'DocxReader', 'HWPReader', @@ -60,8 +65,6 @@ 'VideoAudioReader', 'SimpleDirectoryReader', 'MineruPDFReader', - 'DocManager', - 'DocListManager', 'DocField', 'DataType', 'IndexBase', diff --git a/lazyllm/tools/rag/doc_impl.py b/lazyllm/tools/rag/doc_impl.py index 2370329ec..b7301f443 100644 --- a/lazyllm/tools/rag/doc_impl.py +++ 
b/lazyllm/tools/rag/doc_impl.py @@ -1,10 +1,10 @@ -import json +import os import threading import time from enum import Enum from pydantic import BaseModel from typing import Callable, Dict, List, Optional, Set, Union, Tuple, Any, Type -from lazyllm import LOG, once_wrapper +from lazyllm import LOG, once_wrapper, reset_on_pickle from lazyllm.module import LLMBase from .transform import (NodeTransform, FuncNodeTransform, SentenceSplitter, LLMParser, TransformArgs, TransformArgs as TArgs) @@ -14,14 +14,13 @@ from .store.document_store import _DocumentStore from .doc_node import DocNode from .data_loaders import DirectoryReader -from .utils import DocListManager, is_sparse, _get_default_db_config -from .global_metadata import GlobalMetadataDesc, RAG_DOC_ID, RAG_KB_ID +from .utils import RAG_DEFAULT_GROUP_NAME, gen_docid, is_sparse, _get_default_db_config +from .global_metadata import GlobalMetadataDesc, RAG_DOC_ID, RAG_DOC_PATH, RAG_KB_ID from .data_type import DataType from .parsing_service import _Processor, DocumentProcessor from .embed_wrapper import _EmbedWrapper from .doc_to_db import SchemaExtractor from dataclasses import dataclass -from itertools import repeat _transmap = dict(function=FuncNodeTransform, sentencesplitter=SentenceSplitter, llm=LLMParser) @@ -64,13 +63,15 @@ def __hash__(self): return hash(self.name) lambda x: x._content, LAZY_IMAGE_GROUP, True) +@reset_on_pickle(('_local_monitor_lock', threading.Lock)) class DocImpl: _builtin_node_groups: Dict[str, Dict] = {} _global_node_groups: Dict[str, Dict] = {} _registered_file_reader: Dict[str, Callable] = {} - def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = None, - doc_files: Optional[str] = None, kb_group_name: Optional[str] = None, + def __init__(self, embed: Dict[str, Callable], dataset_path: Optional[str] = None, + enable_path_monitoring: bool = False, + doc_files: Optional[List[str]] = None, kb_group_name: Optional[str] = None, global_metadata_desc: Dict[str, GlobalMetadataDesc] = None, store: Optional[Union[Dict, LazyLLMStoreBase]] = None, processor: Optional[DocumentProcessor] = None, algo_name: Optional[str] = None, @@ -78,9 +79,15 @@ def __init__(self, embed: Dict[str, Callable], dlm: Optional[DocListManager] = N schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): super().__init__() self._local_file_reader: Dict[str, Callable] = {} - self._kb_group_name = kb_group_name or DocListManager.DEFAULT_GROUP_NAME - self._dlm, self._doc_files = dlm, doc_files + self._kb_group_name = kb_group_name or RAG_DEFAULT_GROUP_NAME + self._dataset_path = dataset_path + self._enable_path_monitoring = bool(enable_path_monitoring and dataset_path and doc_files is None) + self._doc_files = doc_files self._reader = DirectoryReader(None, self._local_file_reader, DocImpl._registered_file_reader) + self._local_monitor_thread: Optional[threading.Thread] = None + self._local_monitor_continue = False + self._local_monitor_interval = 10 + self._local_monitor_lock = threading.Lock() self.node_groups: Dict[str, Dict] = { LAZY_ROOT_NAME: dict(parent=None, display_name='Original Source', group_type=NodeGroupType.ORIGINAL), LAZY_IMAGE_GROUP: dict(parent=None, display_name='Image Node', group_type=NodeGroupType.OTHER) @@ -173,11 +180,11 @@ def _lazy_init(self) -> None: self._init_node_groups() self._create_schema_extractor() self._create_store() - cloud = not (self._dlm or self._doc_files is not None) + cloud = not (self._dataset_path or self._doc_files is not None) 
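+        # `cloud` is True only when neither a local dataset path nor temporary doc files
+        # are provided; the local-file sync below is skipped in that case.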
self._resolve_index_pending_registrations() if self._processor: - assert cloud and isinstance(self._processor, DocumentProcessor) + assert isinstance(self._processor, DocumentProcessor) self._processor.register_algorithm(self._algo_name, self._store, self._reader, self.node_groups, self._schema_extractor, self._display_name, self._description) else: @@ -185,18 +192,10 @@ def _lazy_init(self) -> None: self._schema_extractor, self._display_name, self._description) # init files when `cloud` is False - if not cloud and self._store.is_group_empty(LAZY_ROOT_NAME): - ids, pathes, metadatas = self._list_files(upload_status=DocListManager.Status.success) - self._processor.add_doc(pathes, ids, metadatas) - if pathes and self._dlm: - self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, - new_status=DocListManager.Status.success) - if self._dlm: - self._init_monitor_event = threading.Event() - self._daemon = threading.Thread(target=self.worker) - self._daemon.daemon = True - self._daemon.start() - self._init_monitor_event.wait() + if not cloud: + self._sync_local_dataset() + if self._dataset_path and self._enable_path_monitoring: + self._start_local_monitoring() def _resolve_index_pending_registrations(self): for index_type, index_cls, index_args, index_kwargs in self._index_pending_registrations: @@ -298,93 +297,83 @@ def add_reader(self, pattern: str, func: Optional[Callable] = None): self._local_file_reader[pattern] = func self._reader._lazy_init.flag.reset() - def _add_doc_to_store_with_status(self, input_files: List[str], ids: List[str], metadatas: List[Dict[str, Any]], - cond_status_list: Optional[List[str]] = None): - success_ids, failed_ids = [], [] - for filepath, doc_id, metadata in zip(input_files, ids or repeat(None), metadatas or repeat(None)): + def _add_doc_to_store(self, input_files: List[str], ids: List[str], metadatas: List[Dict[str, Any]]): + for filepath, doc_id, metadata in zip(input_files, ids, metadatas): try: - self._add_doc_to_store(input_files=[filepath], ids=[doc_id] if doc_id is not None else None, - metadatas=[metadata] if metadata is not None else None) - success_ids.append(doc_id) + self._processor.add_doc([filepath], [doc_id], [metadata] if metadata is not None else None) except Exception as e: LOG.error(f'Error adding document {doc_id} ({filepath}) to store: {e}') - failed_ids.append(doc_id) - - if success_ids: - self._dlm.update_kb_group(cond_file_ids=success_ids, cond_group=self._kb_group_name, - cond_status_list=cond_status_list, new_status=DocListManager.Status.success) - if failed_ids: - self._dlm.update_kb_group(cond_file_ids=failed_ids, cond_group=self._kb_group_name, - cond_status_list=cond_status_list, new_status=DocListManager.Status.failed) - - def _batch_call(self, func: Callable, *args, batch_size: int = 10, **kwargs): - batch_count = next((len(arg) for arg in args if isinstance(arg, (tuple, list))), 0) - for i in range(0, batch_count, batch_size): - func(*[arg[i:i + batch_size] if isinstance(arg, (list, tuple)) else arg for arg in args], **kwargs) - - def worker(self): - is_first_run = True - while True: - # Apply meta changes - rows = self._dlm.fetch_docs_changed_meta(self._kb_group_name) - for row in rows: - new_meta_dict = json.loads(row[1]) if row[1] else {} - self._processor.update_doc_meta(doc_id=row[0], metadata=new_meta_dict) - - # Step 1: do doc-parsing, highest priority - docs = self._dlm.get_docs_need_reparse(group=self._kb_group_name) - if docs: - filepaths = [doc.path for doc in docs] - ids = [doc.doc_id for doc in 
docs] - metadatas = [json.loads(doc.meta) if doc.meta else None for doc in docs] - # update status and need_reparse - self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, - new_status=DocListManager.Status.working, new_need_reparse=False) - self._delete_doc_from_store(doc_ids=ids) - self._batch_call(self._add_doc_to_store_with_status, filepaths, ids, metadatas, batch_size=10) - - # Step 2: After doc is deleted from related kb_group, delete doc from db - if self._kb_group_name == DocListManager.DEFAULT_GROUP_NAME: - self._dlm.delete_unreferenced_doc() - - # Step 3: do doc-deleting - ids, files, metadatas = self._list_files(status=DocListManager.Status.deleting) - if files: - self._delete_doc_from_store(doc_ids=ids) - self._dlm.delete_files_from_kb_group(ids, self._kb_group_name) - - # Step 4: do doc-adding - ids, files, metadatas = self._list_files(status=DocListManager.Status.waiting, - upload_status=DocListManager.Status.success) - if files: - self._dlm.update_kb_group(cond_file_ids=ids, cond_group=self._kb_group_name, - new_status=DocListManager.Status.working) - self._batch_call(self._add_doc_to_store_with_status, - files, ids, metadatas, cond_status_list=[DocListManager.Status.working]) - - if is_first_run: - self._init_monitor_event.set() - is_first_run = False - time.sleep(10) - - def _list_files( - self, status: Union[str, List[str]] = DocListManager.Status.all, - upload_status: Union[str, List[str]] = DocListManager.Status.all - ) -> Tuple[List[str], List[str], List[Dict]]: - if self._doc_files is not None: return None, self._doc_files, None - if not self._dlm: return [], [], [] - ids, paths, metadatas = [], [], [] - for row in self._dlm.list_kb_group_files(group=self._kb_group_name, status=status, - upload_status=upload_status, details=True): - ids.append(row[0]) - paths.append(row[1]) - metadatas.append(json.loads(row[3]) if row[3] else {}) - return ids, paths, metadatas - - def _add_doc_to_store(self, input_files: List[str], ids: Optional[List[str]] = None, - metadatas: Optional[List[Dict[str, Any]]] = None): - if not input_files: return - self._processor.add_doc(input_files, ids, metadatas) + + def _list_dataset_files(self) -> List[str]: + if not self._dataset_path or not os.path.exists(self._dataset_path): + return [] + if not os.path.isdir(self._dataset_path): + filename = os.path.basename(self._dataset_path) + if filename.startswith('.'): + return [] + return [self._dataset_path] if os.path.isfile(self._dataset_path) else [] + + files = [] + for root, dirs, names in os.walk(os.path.abspath(self._dataset_path)): + path_parts = root.split(os.sep) + if any(part.startswith('.') for part in path_parts if part): + continue + dirs[:] = [name for name in dirs if not name.startswith('.')] + files.extend(os.path.join(root, name) for name in names if not name.startswith('.')) + return sorted(files) + + def _list_local_files(self) -> Tuple[List[str], List[str], List[Dict[str, Any]]]: + paths = list(self._doc_files) if self._doc_files is not None else self._list_dataset_files() + ids = [gen_docid(path) for path in paths] + return ids, paths, [{} for _ in paths] + + def _list_store_root_docs(self) -> Dict[str, str]: + docs = {} + for node in self._store.get_nodes(group=LAZY_ROOT_NAME): + doc_id = node.global_metadata.get(RAG_DOC_ID) + path = node.global_metadata.get(RAG_DOC_PATH) + if doc_id and path: + docs.setdefault(path, doc_id) + return docs + + def _sync_local_dataset(self): + with self._local_monitor_lock: + ids, paths, metadatas = self._list_local_files() + 
current_docs = dict(zip(paths, ids)) + store_docs = self._list_store_root_docs() + + stale_doc_ids = [doc_id for path, doc_id in store_docs.items() if path not in current_docs] + if stale_doc_ids: + self._delete_doc_from_store(doc_ids=stale_doc_ids) + + pending_paths = [path for path in paths if path not in store_docs] + if pending_paths: + doc_id_map = dict(zip(paths, ids)) + metadata_map = dict(zip(paths, metadatas)) + self._add_doc_to_store( + pending_paths, + [doc_id_map[path] for path in pending_paths], + [metadata_map[path] for path in pending_paths], + ) + self._local_files = set(self._list_store_root_docs().keys()) + + def _monitor_local_dataset_worker(self): + while self._local_monitor_continue: + try: + current_paths = set(self._list_dataset_files()) + if current_paths != self._local_files: + self._sync_local_dataset() + except Exception as exc: # pragma: no cover - defensive + LOG.error(f'Failed to sync local dataset `{self._dataset_path}`: {exc}') + time.sleep(self._local_monitor_interval) + + def _start_local_monitoring(self): + if self._local_monitor_thread and self._local_monitor_thread.is_alive(): + return + self._local_monitor_continue = True + self._local_monitor_thread = threading.Thread(target=self._monitor_local_dataset_worker) + self._local_monitor_thread.daemon = True + self._local_monitor_thread.start() def _delete_doc_from_store(self, doc_ids: List[str] = None) -> None: self._processor.delete_doc(doc_ids=doc_ids) @@ -551,11 +540,15 @@ def _register_schema_set(self, schema_set: Type[BaseModel], kb_id: Optional[str] return set_id def _get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, - group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None - ) -> List[DocNode]: + group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None, + limit: Optional[int] = None, offset: int = 0, return_total: bool = False, + sort_by_number: bool = False) -> Union[List[DocNode], Tuple[List[DocNode], int]]: self._lazy_init() - return self._store.get_nodes(uids=uids, doc_ids=doc_ids, group=group, - kb_id=kb_id, numbers=numbers, display=True) + return self._store.get_nodes( + uids=uids, doc_ids=doc_ids, group=group, kb_id=kb_id, numbers=numbers, + limit=limit, offset=offset, return_total=return_total, + sort_by_number=sort_by_number, display=True, + ) def _get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5), merge: bool = False) -> Union[List[DocNode], DocNode]: diff --git a/lazyllm/tools/rag/doc_manager.py b/lazyllm/tools/rag/doc_manager.py index 7ef5d8157..214b725b9 100644 --- a/lazyllm/tools/rag/doc_manager.py +++ b/lazyllm/tools/rag/doc_manager.py @@ -7,13 +7,14 @@ from starlette.responses import RedirectResponse import lazyllm -from lazyllm import LOG, FastapiApp as app +from lazyllm import LOG, FastapiApp as app, deprecated from lazyllm.thirdparty import fastapi from .utils import DocListManager, BaseResponse, gen_docid from .global_metadata import RAG_DOC_ID, RAG_DOC_PATH import uuid +@deprecated('lazyllm.tools.rag.doc_service.DocServer') class DocManager(lazyllm.ModuleBase): def __init__(self, dlm: DocListManager) -> None: super().__init__() diff --git a/lazyllm/tools/rag/doc_service/__init__.py b/lazyllm/tools/rag/doc_service/__init__.py new file mode 100644 index 000000000..9f3ef78e9 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/__init__.py @@ -0,0 +1,4 @@ +from .doc_server import DocServer +from .doc_manager import DocManager + +__all__ = ['DocServer', 'DocManager'] diff 
--git a/lazyllm/tools/rag/doc_service/base.py b/lazyllm/tools/rag/doc_service/base.py new file mode 100644 index 000000000..c3e80fc0e --- /dev/null +++ b/lazyllm/tools/rag/doc_service/base.py @@ -0,0 +1,420 @@ +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional +from uuid import uuid4 + +from pydantic import AliasChoices, BaseModel, Field, model_validator +from ..parsing_service.base import TaskType + + +class DocStatus(str, Enum): + WAITING = 'WAITING' + WORKING = 'WORKING' + SUCCESS = 'SUCCESS' + FAILED = 'FAILED' + CANCELED = 'CANCELED' + DELETING = 'DELETING' + DELETED = 'DELETED' + + +class KBStatus(str, Enum): + ACTIVE = 'ACTIVE' + DELETING = 'DELETING' + DELETED = 'DELETED' + + +class SourceType(str, Enum): + API = 'API' + SCAN = 'SCAN' + TEMP = 'TEMP' + EXTERNAL = 'EXTERNAL' + + +class CallbackEventType(str, Enum): + START = 'START' + FINISH = 'FINISH' + + +BIZ_HTTP_CODE = { + 'E_INVALID_PARAM': 400, + 'E_NOT_FOUND': 404, + 'E_STATE_CONFLICT': 409, + 'E_IDEMPOTENCY_CONFLICT': 409, + 'E_IDEMPOTENCY_IN_PROGRESS': 409, +} + + +class DocServiceError(Exception): + def __init__(self, biz_code: str, msg: str, data: Optional[Dict[str, Any]] = None): + super().__init__(biz_code, msg, data) + self.biz_code = biz_code + self.msg = msg + self.data = data or {} + + @property + def http_status(self): + return BIZ_HTTP_CODE.get(self.biz_code, 500) + + +class TaskCallbackRequest(BaseModel): + callback_id: str = Field(default_factory=lambda: str(uuid4())) + task_id: str + event_type: CallbackEventType + status: DocStatus + error_code: Optional[str] = None + error_msg: Optional[str] = None + payload: Dict[str, Any] = Field(default_factory=dict) + + +class AddFileItem(BaseModel): + file_path: str + doc_id: Optional[str] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode='after') + def validate_file_path(self): + if not self.file_path or not self.file_path.strip(): + raise ValueError('file_path is required') + return self + + +class DocItemsRequest(BaseModel): + items: List[AddFileItem] + kb_id: str = '__default__' + algo_id: str = '__default__' + source_type: Optional[SourceType] = None + idempotency_key: Optional[str] = None + + @model_validator(mode='after') + def validate_items(self): + if not self.items: + raise ValueError('items is required') + return self + + +AddRequest = DocItemsRequest +UploadRequest = DocItemsRequest + + +class _DocMutationRequest(BaseModel): + doc_ids: List[str] + kb_id: str = '__default__' + algo_id: str = '__default__' + idempotency_key: Optional[str] = None + + @model_validator(mode='after') + def validate_doc_ids(self): + if not self.doc_ids: + raise ValueError('doc_ids is required') + return self + + +class ReparseRequest(_DocMutationRequest): + pass + + +class DeleteRequest(_DocMutationRequest): + pass + + +class TransferItem(BaseModel): + doc_id: str + target_doc_id: str + kb_id: str = Field(default='__default__', validation_alias=AliasChoices('kb_id', 'source_kb_id')) + algo_id: str = Field(default='__default__', validation_alias=AliasChoices('algo_id', 'source_algo_id')) + target_kb_id: str + target_algo_id: str + target_metadata: Optional[Dict[str, Any]] = None + target_filename: Optional[str] = None + target_file_path: Optional[str] = None + mode: str = 'copy' + + @model_validator(mode='before') + @classmethod + def normalize_source_fields(cls, data): + if not isinstance(data, dict): + return data + normalized = dict(data) + if 'kb_id' not 
in normalized and 'source_kb_id' in normalized: + normalized['kb_id'] = normalized['source_kb_id'] + if 'algo_id' not in normalized and 'source_algo_id' in normalized: + normalized['algo_id'] = normalized['source_algo_id'] + return normalized + + @property + def source_kb_id(self) -> str: + return self.kb_id + + @property + def source_algo_id(self) -> str: + return self.algo_id + + +class TransferRequest(BaseModel): + items: List[TransferItem] + idempotency_key: Optional[str] = None + + @model_validator(mode='after') + def validate_items(self): + if not self.items: + raise ValueError('items is required') + return self + + +class MetadataPatchItem(BaseModel): + doc_id: str + patch: Dict[str, Any] = Field(default_factory=dict) + + +class MetadataPatchRequest(BaseModel): + items: List[MetadataPatchItem] + kb_id: str = '__default__' + algo_id: str = '__default__' + idempotency_key: Optional[str] = None + + @model_validator(mode='after') + def validate_items(self): + if not self.items: + raise ValueError('items is required') + return self + + +class KbDeleteBatchRequest(BaseModel): + kb_ids: List[str] + idempotency_key: Optional[str] = None + + @model_validator(mode='after') + def validate_kb_ids(self): + if not self.kb_ids: + raise ValueError('kb_ids is required') + return self + + +class AlgorithmInfoRequest(BaseModel): + algo_id: str + + +class TaskInfoRequest(BaseModel): + task_id: str + + +class TaskBatchRequest(BaseModel): + task_ids: List[str] + + @model_validator(mode='after') + def validate_task_ids(self): + if not self.task_ids: + raise ValueError('task_ids is required') + return self + + +class TaskCancelRequest(BaseModel): + task_id: str + idempotency_key: Optional[str] = None + + +class TaskCallbackPayload(BaseModel): + callback_id: Optional[str] = None + task_id: Optional[str] = None + event_type: Optional[CallbackEventType] = None + status: Optional[DocStatus] = None + task_status: Optional[DocStatus] = None + error_code: Optional[str] = None + error_msg: Optional[str] = None + task_type: Optional[TaskType] = None + doc_id: Optional[str] = None + kb_id: Optional[str] = None + algo_id: Optional[str] = None + payload: Dict[str, Any] = Field(default_factory=dict) + + +class KbRequest(BaseModel): + kb_id: Optional[str] = None + display_name: Optional[str] = None + description: Optional[str] = None + owner_id: Optional[str] = None + meta: Optional[Dict[str, Any]] = None + algo_id: str = '__default__' + idempotency_key: Optional[str] = None + + +KbCreateRequest = KbRequest +KbUpdateRequest = KbRequest + + +class KbBatchQueryRequest(BaseModel): + kb_ids: List[str] + + @model_validator(mode='after') + def validate_kb_ids(self): + if not self.kb_ids: + raise ValueError('kb_ids is required') + return self + + +IDEMPOTENCY_RECORDS_TABLE_INFO = { + 'name': 'lazyllm_idempotency_records', + 'comment': 'Idempotency replay records', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'endpoint', 'data_type': 'string', 'nullable': False, 'comment': 'Endpoint name'}, + {'name': 'idempotency_key', 'data_type': 'string', 'nullable': False, 'comment': 'Idempotency key'}, + {'name': 'req_hash', 'data_type': 'string', 'nullable': False, 'comment': 'Request hash'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'comment': 'Processing status'}, + {'name': 'response_json', 'data_type': 'text', 'nullable': True, 'comment': 'Response json'}, + {'name': 'created_at', 'data_type': 'datetime', 
'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +CALLBACK_RECORDS_TABLE_INFO = { + 'name': 'lazyllm_callback_records', + 'comment': 'Processed callback records', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'callback_id', 'data_type': 'string', 'nullable': False, 'comment': 'Callback ID'}, + {'name': 'task_id', 'data_type': 'string', 'nullable': False, 'comment': 'Task ID'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + ], +} + + +DOC_SERVICE_TASKS_TABLE_INFO = { + 'name': 'lazyllm_doc_service_tasks', + 'comment': 'Doc service task history table', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'task_id', 'data_type': 'string', 'nullable': False, 'comment': 'Task ID'}, + {'name': 'task_type', 'data_type': 'string', 'nullable': False, 'comment': 'Task type'}, + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'comment': 'Document ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'algo_id', 'data_type': 'string', 'nullable': False, 'comment': 'Algorithm ID'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'comment': 'Current task status'}, + {'name': 'message', 'data_type': 'text', 'nullable': True, 'comment': 'Task payload in JSON string'}, + {'name': 'error_code', 'data_type': 'string', 'nullable': True, 'comment': 'Error code'}, + {'name': 'error_msg', 'data_type': 'text', 'nullable': True, 'comment': 'Error message'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + {'name': 'started_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Started time'}, + {'name': 'finished_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Finished time'}, + ], +} + + +DOCUMENTS_TABLE_INFO = { + 'name': 'lazyllm_documents', + 'comment': 'Document metadata table', + 'columns': [ + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'is_primary_key': True, + 'comment': 'Document ID'}, + {'name': 'filename', 'data_type': 'string', 'nullable': False, 'comment': 'Filename'}, + {'name': 'path', 'data_type': 'string', 'nullable': False, 'comment': 'Absolute file path'}, + {'name': 'meta', 'data_type': 'text', 'nullable': True, 'comment': 'Document metadata in JSON string'}, + {'name': 'upload_status', 'data_type': 'string', 'nullable': False, 'comment': 'Document upload status'}, + {'name': 'source_type', 'data_type': 'string', 'nullable': False, 'comment': 'Source type'}, + {'name': 'file_type', 'data_type': 'string', 'nullable': True, 'comment': 'File type suffix'}, + {'name': 'content_hash', 'data_type': 'string', 'nullable': True, 'comment': 'Content hash'}, + {'name': 'size_bytes', 'data_type': 'integer', 'nullable': True, 'comment': 'File size in bytes'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': 
datetime.now, + 'comment': 'Updated time'}, + ], +} + + +KBS_TABLE_INFO = { + 'name': 'lazyllm_knowledge_bases', + 'comment': 'Knowledge base table', + 'columns': [ + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'is_primary_key': True, + 'comment': 'Knowledge base ID'}, + {'name': 'display_name', 'data_type': 'string', 'nullable': True, 'comment': 'Display name'}, + {'name': 'description', 'data_type': 'string', 'nullable': True, 'comment': 'Description'}, + {'name': 'doc_count', 'data_type': 'integer', 'nullable': False, 'default': 0, 'comment': 'Document count'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'default': KBStatus.ACTIVE.value, + 'comment': 'KB status'}, + {'name': 'owner_id', 'data_type': 'string', 'nullable': True, 'comment': 'Owner ID'}, + {'name': 'meta', 'data_type': 'text', 'nullable': True, 'comment': 'KB metadata in JSON string'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +KB_DOCUMENTS_TABLE_INFO = { + 'name': 'lazyllm_kb_documents', + 'comment': 'KB and document binding table', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'comment': 'Document ID'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +KB_ALGORITHM_TABLE_INFO = { + 'name': 'lazyllm_kb_algorithm', + 'comment': 'KB and algorithm binding table', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'algo_id', 'data_type': 'string', 'nullable': False, 'comment': 'Algorithm ID'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} + + +PARSE_STATE_TABLE_INFO = { + 'name': 'lazyllm_doc_parse_state', + 'comment': 'Latest parse state snapshot table', + 'columns': [ + {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, + 'comment': 'Auto increment ID'}, + {'name': 'doc_id', 'data_type': 'string', 'nullable': False, 'comment': 'Document ID'}, + {'name': 'kb_id', 'data_type': 'string', 'nullable': False, 'comment': 'Knowledge base ID'}, + {'name': 'algo_id', 'data_type': 'string', 'nullable': False, 'comment': 'Algorithm ID'}, + {'name': 'status', 'data_type': 'string', 'nullable': False, 'comment': 'Current parse status'}, + {'name': 'current_task_id', 'data_type': 'string', 'nullable': True, 'comment': 'Current task ID'}, + {'name': 'task_type', 'data_type': 'string', 'nullable': True, 'comment': 'Current task type'}, + {'name': 'idempotency_key', 'data_type': 'string', 'nullable': True, 'comment': 'Idempotency key'}, + {'name': 'priority', 'data_type': 'integer', 'nullable': False, 'default': 0, 
'comment': 'Task priority'}, + {'name': 'task_score', 'data_type': 'integer', 'nullable': True, 'comment': 'Task score'}, + {'name': 'retry_count', 'data_type': 'integer', 'nullable': False, 'default': 0, 'comment': 'Retry count'}, + {'name': 'max_retry', 'data_type': 'integer', 'nullable': False, 'default': 3, 'comment': 'Max retry'}, + {'name': 'lease_owner', 'data_type': 'string', 'nullable': True, 'comment': 'Lease owner'}, + {'name': 'lease_until', 'data_type': 'datetime', 'nullable': True, 'comment': 'Lease deadline'}, + {'name': 'last_error_code', 'data_type': 'string', 'nullable': True, 'comment': 'Last error code'}, + {'name': 'last_error_msg', 'data_type': 'text', 'nullable': True, 'comment': 'Last error message'}, + {'name': 'failed_stage', 'data_type': 'string', 'nullable': True, 'comment': 'Failure stage'}, + {'name': 'queued_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Queued time'}, + {'name': 'started_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Started time'}, + {'name': 'finished_at', 'data_type': 'datetime', 'nullable': True, 'comment': 'Finished time'}, + {'name': 'created_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Created time'}, + {'name': 'updated_at', 'data_type': 'datetime', 'nullable': False, 'default': datetime.now, + 'comment': 'Updated time'}, + ], +} diff --git a/lazyllm/tools/rag/doc_service/doc_manager.py b/lazyllm/tools/rag/doc_service/doc_manager.py new file mode 100644 index 000000000..a284464e7 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/doc_manager.py @@ -0,0 +1,1807 @@ +from __future__ import annotations + +from datetime import datetime, timedelta +import json +import os +from typing import Any, Dict, List, Optional, Set +from uuid import uuid4 + +import sqlalchemy +from sqlalchemy.exc import IntegrityError + +from lazyllm import LOG +from lazyllm.thirdparty import fastapi + +from ..utils import BaseResponse, _get_default_db_config, _orm_to_dict +from ...sql import SqlManager +from .base import ( + AddRequest, CALLBACK_RECORDS_TABLE_INFO, CallbackEventType, DOC_SERVICE_TASKS_TABLE_INFO, + DOCUMENTS_TABLE_INFO, DeleteRequest, DocServiceError, DocStatus, IDEMPOTENCY_RECORDS_TABLE_INFO, + KB_ALGORITHM_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KBStatus, MetadataPatchRequest, + PARSE_STATE_TABLE_INFO, ReparseRequest, SourceType, TaskCallbackRequest, TaskType, TransferRequest, + UploadRequest, +) +from .parser_client import ParserClient +from .utils import ( + from_json, gen_doc_id, hash_payload, merge_transfer_metadata, resolve_transfer_target_path, + sha256_file, stable_json, to_json, +) + + +class DocManager: + def __init__( + self, + db_config: Optional[Dict[str, Any]] = None, + parser_url: Optional[str] = None, + callback_url: Optional[str] = None, + ): + if not parser_url: + raise ValueError('parser_url is required') + + self._db_config = db_config or _get_default_db_config('doc_service') + self._db_manager = SqlManager( + **self._db_config, + tables_info_dict={ + 'tables': [ + DOCUMENTS_TABLE_INFO, KBS_TABLE_INFO, KB_DOCUMENTS_TABLE_INFO, KB_ALGORITHM_TABLE_INFO, + PARSE_STATE_TABLE_INFO, IDEMPOTENCY_RECORDS_TABLE_INFO, CALLBACK_RECORDS_TABLE_INFO, + DOC_SERVICE_TASKS_TABLE_INFO, + ] + }, + ) + self._ensure_indexes() + self._parser_client = ParserClient(parser_url=parser_url) + self._parser_client.health() + self._callback_url = callback_url + + def set_callback_url(self, callback_url: str): + self._callback_url = callback_url + + def _ensure_indexes(self): + 
stmts = [ + 'DROP INDEX IF EXISTS uq_docs_path', + 'CREATE INDEX IF NOT EXISTS idx_docs_path ON lazyllm_documents(path)', + 'CREATE INDEX IF NOT EXISTS idx_documents_upload_status ON lazyllm_documents(upload_status)', + 'CREATE INDEX IF NOT EXISTS idx_documents_updated_at ON lazyllm_documents(updated_at)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_display_name ' + 'ON lazyllm_knowledge_bases(display_name) WHERE display_name IS NOT NULL', + 'CREATE INDEX IF NOT EXISTS idx_kb_created_at ON lazyllm_knowledge_bases(created_at)', + 'CREATE INDEX IF NOT EXISTS idx_kb_updated_at ON lazyllm_knowledge_bases(updated_at)', + 'CREATE INDEX IF NOT EXISTS idx_kb_doc_count ON lazyllm_knowledge_bases(doc_count)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_documents ON lazyllm_kb_documents(kb_id, doc_id)', + 'CREATE INDEX IF NOT EXISTS idx_kb_documents_doc_id ON lazyllm_kb_documents(doc_id)', + 'CREATE INDEX IF NOT EXISTS idx_kb_documents_kb_id ON lazyllm_kb_documents(kb_id)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_kb_algorithm_kb_id ON lazyllm_kb_algorithm(kb_id)', + 'CREATE INDEX IF NOT EXISTS idx_kb_algorithm_algo_id ON lazyllm_kb_algorithm(algo_id)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_parse_state_key ' + 'ON lazyllm_doc_parse_state(doc_id, kb_id, algo_id)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_parse_state_current_task ' + 'ON lazyllm_doc_parse_state(current_task_id) WHERE current_task_id IS NOT NULL', + 'CREATE INDEX IF NOT EXISTS idx_parse_sched ' + 'ON lazyllm_doc_parse_state(status, task_score, updated_at)', + 'CREATE INDEX IF NOT EXISTS idx_parse_lease ' + 'ON lazyllm_doc_parse_state(status, lease_until)', + 'CREATE INDEX IF NOT EXISTS idx_parse_kb_algo_status ' + 'ON lazyllm_doc_parse_state(kb_id, algo_id, status)', + 'CREATE INDEX IF NOT EXISTS idx_parse_task_type_status ' + 'ON lazyllm_doc_parse_state(task_type, status, updated_at)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_idempotency_endpoint_key ' + 'ON lazyllm_idempotency_records(endpoint, idempotency_key)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_callback_id ' + 'ON lazyllm_callback_records(callback_id)', + 'CREATE UNIQUE INDEX IF NOT EXISTS uq_doc_service_task_id ' + 'ON lazyllm_doc_service_tasks(task_id)', + 'CREATE INDEX IF NOT EXISTS idx_doc_service_task_status ' + 'ON lazyllm_doc_service_tasks(status, updated_at)', + 'CREATE INDEX IF NOT EXISTS idx_doc_service_task_doc ' + 'ON lazyllm_doc_service_tasks(doc_id, kb_id, algo_id)', + ] + for stmt in stmts: + self._db_manager.execute_commit(stmt) + + def _ensure_default_kb(self): + self._ensure_kb('__default__', display_name='__default__') + self._ensure_kb_algorithm('__default__', '__default__') + self._cleanup_idempotency_records() + + def _ensure_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + update_fields: Optional[Set[str]] = None): + now = datetime.now() + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if row is None: + row = Kb(kb_id=kb_id, display_name=display_name, description=description, doc_count=0, + status=KBStatus.ACTIVE.value, owner_id=owner_id, meta=to_json(meta), + created_at=now, updated_at=now) + else: + if update_fields is None: + update_fields = set() + if 'display_name' in update_fields: + row.display_name = display_name + if 'description' in update_fields: + row.description = description + if 'owner_id' in 
update_fields: + row.owner_id = owner_id + if 'meta' in update_fields: + row.meta = to_json(meta) + if row.status == KBStatus.DELETED.value: + row.status = KBStatus.ACTIVE.value + row.updated_at = now + session.add(row) + + def _ensure_kb_algorithm(self, kb_id: str, algo_id: str): + now = datetime.now() + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id).first() + if row is None: + row = Rel(kb_id=kb_id, algo_id=algo_id, created_at=now, updated_at=now) + elif row.algo_id != algo_id: + raise DocServiceError('E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm {row.algo_id}', + {'kb_id': kb_id, 'bound_algo_id': row.algo_id, + 'requested_algo_id': algo_id}) + else: + row.updated_at = now + session.add(row) + + def _get_kb(self, kb_id: str): + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + return _orm_to_dict(row) if row else None + + def _get_kb_algorithm(self, kb_id: str): + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id).first() + return _orm_to_dict(row) if row else None + + @staticmethod + def _build_kb_data(kb_row, algo_row=None): + return { + 'kb_id': kb_row.kb_id, + 'display_name': kb_row.display_name, + 'description': kb_row.description, + 'doc_count': kb_row.doc_count, + 'status': kb_row.status, + 'owner_id': kb_row.owner_id, + 'meta': from_json(kb_row.meta), + 'algo_id': algo_row.algo_id if algo_row is not None else None, + 'created_at': kb_row.created_at, + 'updated_at': kb_row.updated_at, + } + + def _validate_kb_algorithm(self, kb_id: str, algo_id: str): + if kb_id == '__default__': + self._ensure_default_kb() + kb = self._get_kb(kb_id) + if kb is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + if kb.get('status') != KBStatus.ACTIVE.value: + raise DocServiceError('E_STATE_CONFLICT', f'kb is not active: {kb_id}', {'kb_id': kb_id}) + binding = self._get_kb_algorithm(kb_id) + if binding is None: + raise DocServiceError('E_STATE_CONFLICT', f'kb has no algorithm binding: {kb_id}', {'kb_id': kb_id}) + if binding['algo_id'] != algo_id: + raise DocServiceError( + 'E_INVALID_PARAM', f'kb {kb_id} is bound to algorithm {binding["algo_id"]}', + {'kb_id': kb_id, 'bound_algo_id': binding['algo_id'], 'requested_algo_id': algo_id} + ) + return binding + + def _ensure_algorithm_exists(self, algo_id: str): + algorithms = self.list_algorithms() + if any(item.get('algo_id') == algo_id for item in algorithms): + return + raise DocServiceError('E_INVALID_PARAM', f'invalid algo_id: {algo_id}', {'algo_id': algo_id}) + + def _ensure_kb_document(self, kb_id: str, doc_id: str): + now = datetime.now() + created = False + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() + if row is None: + created = True + row = Rel(kb_id=kb_id, doc_id=doc_id, created_at=now, updated_at=now) + else: + row.updated_at = now + session.add(row) + if created: + self._refresh_kb_doc_count(kb_id) + return created + + def _remove_kb_document(self, kb_id: str, doc_id: str): + removed = False + with self._db_manager.get_session() as 
session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + row = session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() + if row is not None: + session.delete(row) + removed = True + if removed: + self._refresh_kb_doc_count(kb_id) + return removed + + def _load_idempotency_record(self, endpoint: str, idempotency_key: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + row = session.query(Record).filter( + Record.endpoint == endpoint, + Record.idempotency_key == idempotency_key, + ).first() + return _orm_to_dict(row) if row else None + + def _cleanup_idempotency_records(self, ttl_days: int = 7): + cutoff = datetime.now() - timedelta(days=ttl_days) + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + session.query(Record).filter(Record.updated_at < cutoff).delete() + + def _claim_idempotency_key(self, endpoint: str, idempotency_key: str, req_hash: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + now = datetime.now() + row = Record( + endpoint=endpoint, + idempotency_key=idempotency_key, + req_hash=req_hash, + status='PROCESSING', + response_json=None, + created_at=now, + updated_at=now, + ) + session.add(row) + session.flush() + return True + + def _complete_idempotency_record(self, endpoint: str, idempotency_key: str, response: Any): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + row = session.query(Record).filter( + Record.endpoint == endpoint, + Record.idempotency_key == idempotency_key, + ).first() + if row is None: + return + row.status = 'COMPLETED' + row.response_json = stable_json(response) + row.updated_at = datetime.now() + session.add(row) + + def _drop_idempotency_claim(self, endpoint: str, idempotency_key: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(IDEMPOTENCY_RECORDS_TABLE_INFO['name']) + row = session.query(Record).filter( + Record.endpoint == endpoint, + Record.idempotency_key == idempotency_key, + ).first() + if row is not None and row.status == 'PROCESSING': + session.delete(row) + + def run_idempotent(self, endpoint: str, idempotency_key: Optional[str], payload: Any, handler): + if not idempotency_key: + return handler() + req_hash = hash_payload(payload) + try: + self._claim_idempotency_key(endpoint, idempotency_key, req_hash) + except IntegrityError: + record = self._load_idempotency_record(endpoint, idempotency_key) + if record is None: + raise DocServiceError('E_IDEMPOTENCY_IN_PROGRESS', 'idempotency request is being processed') + if record['req_hash'] != req_hash: + raise DocServiceError('E_IDEMPOTENCY_CONFLICT', 'idempotency key conflicts with different request') + if record.get('status') == 'COMPLETED' and record.get('response_json'): + return json.loads(record['response_json']) + raise DocServiceError( + 'E_IDEMPOTENCY_IN_PROGRESS', 'idempotency request is being processed', + {'endpoint': endpoint, 'idempotency_key': idempotency_key} + ) + try: + response = handler() + except Exception: + self._drop_idempotency_claim(endpoint, idempotency_key) + raise + self._complete_idempotency_record(endpoint, idempotency_key, response) + return response + + def _record_callback(self, callback_id: str, 
task_id: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(CALLBACK_RECORDS_TABLE_INFO['name']) + session.add(Record(callback_id=callback_id, task_id=task_id, created_at=datetime.now())) + try: + session.flush() + return True + except IntegrityError: + session.rollback() + return False + + def _forget_callback_record(self, callback_id: str, task_id: str): + with self._db_manager.get_session() as session: + Record = self._db_manager.get_table_orm_class(CALLBACK_RECORDS_TABLE_INFO['name']) + session.query(Record).filter( + Record.callback_id == callback_id, + Record.task_id == task_id, + ).delete() + + def _create_task_record(self, task_id: str, task_type: TaskType, doc_id: str, kb_id: str, algo_id: str, + status: DocStatus, message: Optional[Dict[str, Any]] = None): + now = datetime.now() + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) + session.add(Task( + task_id=task_id, + task_type=task_type.value, + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=status.value, + message=to_json(message), + error_code=None, + error_msg=None, + created_at=now, + updated_at=now, + started_at=None, + finished_at=None, + )) + return self._get_task_record(task_id) + + def _get_task_record(self, task_id: str): + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) + row = session.query(Task).filter(Task.task_id == task_id).first() + if row is None: + return None + task = _orm_to_dict(row) + task['message'] = from_json(task.get('message')) + return task + + def _update_task_record(self, task_id: str, **fields): + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) + row = session.query(Task).filter(Task.task_id == task_id).first() + if row is None: + return None + for key, value in fields.items(): + setattr(row, key, value) + row.updated_at = datetime.now() + session.add(row) + return self._get_task_record(task_id) + + def _refresh_kb_doc_count(self, kb_id: str): + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if kb_row is None: + return + kb_row.doc_count = session.query(Rel).filter(Rel.kb_id == kb_id).count() + if kb_row.status == KBStatus.DELETING.value and kb_row.doc_count == 0: + kb_row.status = KBStatus.DELETED.value + kb_row.updated_at = datetime.now() + session.add(kb_row) + + def _list_kb_doc_ids(self, kb_id: str) -> List[str]: + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + rows = session.query(Rel.doc_id).filter(Rel.kb_id == kb_id).all() + return [row[0] for row in rows] + + def _has_kb_document(self, kb_id: str, doc_id: str): + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + return session.query(Rel).filter(Rel.kb_id == kb_id, Rel.doc_id == doc_id).first() is not None + + def _doc_relation_count(self, doc_id: str): + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + return session.query(Rel).filter(Rel.doc_id == doc_id).count() + + def _get_doc(self, 
doc_id: str): + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + return _orm_to_dict(row) if row else None + + def _get_doc_by_path(self, path: str): + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.path == path).first() + return _orm_to_dict(row) if row else None + + def _upsert_doc( + self, + doc_id: str, + filename: str, + path: str, + metadata: Dict[str, Any], + source_type: SourceType, + upload_status: DocStatus = DocStatus.SUCCESS, + allowed_path_doc_ids: Optional[Set[str]] = None, + ): + now = datetime.now() + file_type = os.path.splitext(path)[1].lstrip('.').lower() or None + size_bytes = os.path.getsize(path) if os.path.exists(path) else None + content_hash = sha256_file(path) if os.path.exists(path) else None + allowed_path_doc_ids = allowed_path_doc_ids or set() + try: + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + path_row = session.query(Doc).filter(Doc.path == path).first() + if path_row is not None and path_row.doc_id != doc_id and path_row.doc_id not in allowed_path_doc_ids: + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc path already exists: {path}', + {'doc_id': path_row.doc_id, 'path': path}, + ) + if row is None: + row = Doc( + doc_id=doc_id, + filename=filename, + path=path, + meta=to_json(metadata), + upload_status=upload_status.value, + source_type=source_type.value, + file_type=file_type, + content_hash=content_hash, + size_bytes=size_bytes, + created_at=now, + updated_at=now, + ) + else: + row.filename = filename + row.path = path + row.meta = to_json(metadata) + row.upload_status = upload_status.value + row.source_type = source_type.value + row.file_type = file_type + row.content_hash = content_hash + row.size_bytes = size_bytes + row.updated_at = now + session.add(row) + except IntegrityError as exc: + existing = self._get_doc_by_path(path) + if ( + existing is not None + and existing.get('doc_id') != doc_id + and existing.get('doc_id') not in allowed_path_doc_ids + ): + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc path already exists: {path}', + {'doc_id': existing['doc_id'], 'path': path}, + ) from exc + raise + return self._get_doc(doc_id) + + def _set_doc_upload_status(self, doc_id: str, status: DocStatus): + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if row is None: + return None + row.upload_status = status.value + row.updated_at = datetime.now() + session.add(row) + return self._get_doc(doc_id) + + def _get_parse_snapshot(self, doc_id: str, kb_id: str, algo_id: str): + with self._db_manager.get_session() as session: + State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + row = ( + session.query(State) + .filter(State.doc_id == doc_id, State.kb_id == kb_id, State.algo_id == algo_id) + .first() + ) + return _orm_to_dict(row) if row else None + + def _get_latest_parse_snapshot(self, doc_id: str, kb_id: str): + with self._db_manager.get_session() as session: + State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + row = ( + session.query(State) + 
.filter(State.doc_id == doc_id, State.kb_id == kb_id) + .order_by(State.updated_at.desc(), State.created_at.desc()) + .first() + ) + return _orm_to_dict(row) if row else None + + def _delete_parse_snapshots(self, doc_id: str, kb_id: str): + with self._db_manager.get_session() as session: + State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + session.query(State).filter(State.doc_id == doc_id, State.kb_id == kb_id).delete() + + def _delete_doc_if_orphaned(self, doc_id: str) -> bool: + if self._doc_relation_count(doc_id) > 0: + return False + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if row is None: + return False + session.delete(row) + return True + + def _purge_deleted_kb_doc_data(self, kb_id: str, doc_id: str, remove_relation: bool = False): + if remove_relation: + self._remove_kb_document(kb_id, doc_id) + self._delete_parse_snapshots(doc_id, kb_id) + if not self._delete_doc_if_orphaned(doc_id): + self._sync_doc_upload_status(doc_id) + + def _mark_task_cleanup_policy(self, task_id: str, cleanup_policy: str): + task = self._get_task_record(task_id) + if task is None: + return + message = task.get('message') or {} + if message.get('cleanup_policy') == cleanup_policy: + return + message['cleanup_policy'] = cleanup_policy + self._update_task_record(task_id, message=to_json(message)) + + def _finalize_kb_deletion_if_empty(self, kb_id: str) -> bool: + with self._db_manager.get_session() as session: + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + if session.query(Rel).filter(Rel.kb_id == kb_id).count() > 0: + return False + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + AlgoRel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if kb_row is not None: + session.delete(kb_row) + session.query(AlgoRel).filter(AlgoRel.kb_id == kb_id).delete() + return True + + def _prepare_kb_delete_items(self, kb_id: str) -> Dict[str, Any]: + kb = self._get_kb(kb_id) + if kb is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + binding = self._get_kb_algorithm(kb_id) + default_algo_id = binding['algo_id'] if binding is not None else '__default__' + items = [] + for doc_id in self._list_kb_doc_ids(kb_id): + snapshot = ( + self._get_parse_snapshot(doc_id, kb_id, default_algo_id) + or self._get_latest_parse_snapshot(doc_id, kb_id) + ) + if snapshot is None or snapshot.get('status') == DocStatus.DELETED.value: + items.append({'action': 'purge_local', 'doc_id': doc_id}) + continue + status = snapshot.get('status') + task_type = snapshot.get('task_type') + if status in (DocStatus.WAITING.value, DocStatus.WORKING.value): + if task_type == TaskType.DOC_DELETE.value and snapshot.get('current_task_id'): + items.append({ + 'action': 'reuse_delete_task', + 'doc_id': doc_id, + 'task_id': snapshot['current_task_id'], + }) + continue + raise DocServiceError( + 'E_STATE_CONFLICT', + f'cannot delete kb while doc {doc_id} task is {status}', + {'kb_id': kb_id, 'doc_id': doc_id, 'status': status, 'task_type': task_type}, + ) + items.append({ + 'action': 'enqueue_delete', + 'doc_id': doc_id, + 'algo_id': snapshot.get('algo_id') or default_algo_id, + }) + return {'kb': kb, 'items': items} + + def _assert_action_allowed(self, doc_id: str, kb_id: str, algo_id: str, action: str): + snapshot = 
self._get_parse_snapshot(doc_id, kb_id, algo_id) + status = snapshot.get('status') if snapshot is not None else None + if status is None and action in ('add', 'upload'): + doc = self._get_doc(doc_id) + status = doc.get('upload_status') if doc else None + + if action in ('add', 'upload'): + if status in ( + DocStatus.WAITING.value, + DocStatus.WORKING.value, + DocStatus.DELETING.value, + DocStatus.SUCCESS.value, + ): + raise DocServiceError('E_STATE_CONFLICT', f'cannot {action} while state is {status}') + return + + if status == DocStatus.WORKING.value and action in ('reparse', 'delete', 'transfer', 'metadata'): + raise DocServiceError('E_STATE_CONFLICT', f'cannot {action} while state is WORKING') + if status == DocStatus.DELETING.value and action in ('reparse', 'delete', 'transfer', 'metadata'): + raise DocServiceError('E_STATE_CONFLICT', f'cannot {action} while state is DELETING') + + def _upsert_parse_snapshot( + self, + doc_id: str, + kb_id: str, + algo_id: str, + status: DocStatus, + task_type: Optional[TaskType] = None, + current_task_id: Optional[str] = None, + idempotency_key: Optional[str] = None, + priority: int = 0, + task_score: Optional[int] = None, + retry_count: int = 0, + max_retry: int = 3, + lease_owner: Optional[str] = None, + lease_until: Optional[datetime] = None, + error_code: Optional[str] = None, + error_msg: Optional[str] = None, + failed_stage: Optional[str] = None, + queued_at: Optional[datetime] = None, + started_at: Optional[datetime] = None, + finished_at: Optional[datetime] = None, + ): + now = datetime.now() + with self._db_manager.get_session() as session: + State = self._db_manager.get_table_orm_class(PARSE_STATE_TABLE_INFO['name']) + row = ( + session.query(State) + .filter(State.doc_id == doc_id, State.kb_id == kb_id, State.algo_id == algo_id) + .first() + ) + if row is None: + row = State( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=status.value, + task_type=task_type.value if task_type else None, + current_task_id=current_task_id, + idempotency_key=idempotency_key, + priority=priority, + task_score=task_score, + retry_count=retry_count, + max_retry=max_retry, + lease_owner=lease_owner, + lease_until=lease_until, + last_error_code=error_code, + last_error_msg=error_msg, + failed_stage=failed_stage, + queued_at=queued_at, + started_at=started_at, + finished_at=finished_at, + created_at=now, + updated_at=now, + ) + else: + row.status = status.value + if task_type is not None: + row.task_type = task_type.value + row.current_task_id = current_task_id + row.idempotency_key = idempotency_key + row.priority = priority + row.task_score = task_score + row.retry_count = retry_count + row.max_retry = max_retry + row.lease_owner = lease_owner + row.lease_until = lease_until + row.last_error_code = error_code + row.last_error_msg = error_msg + row.failed_stage = failed_stage + row.queued_at = queued_at + row.started_at = started_at + row.finished_at = finished_at + row.updated_at = now + session.add(row) + return self._get_parse_snapshot(doc_id, kb_id, algo_id) + + def _validate_unique_doc_ids(self, doc_ids: List[str], field_name: str = 'doc_id'): + duplicated = set() + seen = set() + for doc_id in doc_ids: + if doc_id in seen: + duplicated.add(doc_id) + seen.add(doc_id) + if duplicated: + duplicated_list = sorted(duplicated) + raise DocServiceError( + 'E_INVALID_PARAM', + f'duplicate {field_name} detected', + {f'duplicate_{field_name}s': duplicated_list}, + ) + + def _call_parser_client(self, method, *args, **kwargs): + try: + return method(*args, 
**kwargs) + except TypeError as exc: + compat_kwargs = dict(kwargs) + removed = False + for field in ('callback_url',): + if field in compat_kwargs and field in str(exc): + compat_kwargs.pop(field, None) + removed = True + if not removed: + raise + return method(*args, **compat_kwargs) + + def _create_parser_task(self, task_id: str, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, + file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + reparse_group: Optional[str] = None, parser_kb_id: Optional[str] = None, + transfer_params: Optional[Dict[str, Any]] = None): + if task_type in (TaskType.DOC_ADD, TaskType.DOC_TRANSFER): + if not file_path: + raise RuntimeError(f'file_path is required for task_type {task_type.value}') + task_resp = self._call_parser_client( + self._parser_client.add_doc, + task_id, algo_id, parser_kb_id or kb_id, doc_id, file_path, metadata, + callback_url=self._callback_url, transfer_params=transfer_params, + ) + elif task_type == TaskType.DOC_REPARSE: + if not file_path: + raise RuntimeError('file_path is required for reparse task') + task_resp = self._call_parser_client( + self._parser_client.add_doc, + task_id, algo_id, kb_id, doc_id, file_path, metadata, + reparse_group=reparse_group or 'all', callback_url=self._callback_url, + ) + elif task_type == TaskType.DOC_UPDATE_META: + task_resp = self._call_parser_client( + self._parser_client.update_meta, + task_id, algo_id, kb_id, doc_id, metadata, file_path, callback_url=self._callback_url, + ) + elif task_type == TaskType.DOC_DELETE: + task_resp = self._call_parser_client( + self._parser_client.delete_doc, + task_id, algo_id, kb_id, doc_id, callback_url=self._callback_url, + ) + else: + raise RuntimeError(f'unsupported task type: {task_type.value}') + if task_resp.code != 200: + raise RuntimeError(f'failed to enqueue parser task: {task_resp.msg}') + return task_id + + def _enqueue_task( + self, doc_id: str, kb_id: str, algo_id: str, task_type: TaskType, + idempotency_key: Optional[str] = None, priority: int = 0, + file_path: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None, + reparse_group: Optional[str] = None, cleanup_policy: Optional[str] = None, + parser_kb_id: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None, + extra_message: Optional[Dict[str, Any]] = None, parser_doc_id: Optional[str] = None, + ): + task_id = str(uuid4()) + task_message = { + 'doc_id': doc_id, + 'kb_id': kb_id, + 'algo_id': algo_id, + 'file_path': file_path, + 'metadata': metadata, + 'reparse_group': reparse_group, + } + if extra_message: + task_message.update(extra_message) + if cleanup_policy: + task_message['cleanup_policy'] = cleanup_policy + if transfer_params: + task_message['transfer_params'] = transfer_params + task_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING + self._create_task_record(task_id, task_type, doc_id, kb_id, algo_id, task_status, message=task_message) + parse_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WAITING + self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=parse_status, + task_type=task_type, + current_task_id=task_id, + idempotency_key=idempotency_key, + priority=priority, + queued_at=datetime.now(), + started_at=None, + finished_at=None, + error_code=None, + error_msg=None, + failed_stage=None, + ) + try: + self._create_parser_task( + task_id, parser_doc_id or doc_id, kb_id, algo_id, task_type, + file_path=file_path, metadata=metadata, 
reparse_group=reparse_group, + parser_kb_id=parser_kb_id, transfer_params=transfer_params, + ) + except Exception as exc: + finished_at = datetime.now() + error_msg = str(exc) + self._update_task_record( + task_id, + status=DocStatus.FAILED.value, + error_code='PARSER_SUBMIT_FAILED', + error_msg=error_msg, + finished_at=finished_at, + ) + self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=DocStatus.FAILED, + **self._build_snapshot_update( + self._get_parse_snapshot(doc_id, kb_id, algo_id), + task_type=task_type, + current_task_id=task_id, + error_code='PARSER_SUBMIT_FAILED', + error_msg=error_msg, + failed_stage='SUBMIT', + finished_at=finished_at, + ), + ) + self._apply_doc_upload_status(doc_id, task_type, DocStatus.FAILED) + raise + return task_id, self._get_parse_snapshot(doc_id, kb_id, algo_id) + + def _apply_doc_upload_status(self, doc_id: str, task_type: TaskType, status: DocStatus): + if task_type == TaskType.DOC_ADD: + if status in (DocStatus.WORKING, DocStatus.FAILED, DocStatus.CANCELED, DocStatus.SUCCESS): + self._set_doc_upload_status(doc_id, status) + return + if task_type == TaskType.DOC_DELETE: + if status == DocStatus.DELETING: + if self._doc_relation_count(doc_id) <= 1: + self._set_doc_upload_status(doc_id, DocStatus.DELETING) + return + if status == DocStatus.DELETED: + target = DocStatus.SUCCESS if self._doc_relation_count(doc_id) > 0 else DocStatus.DELETED + self._set_doc_upload_status(doc_id, target) + return + if status in (DocStatus.FAILED, DocStatus.CANCELED): + target = DocStatus.SUCCESS if self._doc_relation_count(doc_id) > 0 else DocStatus.DELETED + self._set_doc_upload_status(doc_id, target) + return + + def _prepare_upload_items(self, request: UploadRequest) -> List[Dict[str, Any]]: + prepared_items: List[Dict[str, Any]] = [] + for item in request.items: + file_path = item.file_path + if not os.path.exists(file_path): + raise DocServiceError('E_INVALID_PARAM', f'file not found: {file_path}') + prepared_items.append({ + 'file_path': file_path, + 'metadata': item.metadata, + 'doc_id': gen_doc_id(file_path, doc_id=item.doc_id), + 'filename': os.path.basename(file_path), + }) + + self._validate_unique_doc_ids([item['doc_id'] for item in prepared_items]) + + for item in prepared_items: + if self._has_kb_document(request.kb_id, item['doc_id']): + self._assert_action_allowed(item['doc_id'], request.kb_id, request.algo_id, 'upload') + return prepared_items + + def _prepare_reparse_items(self, request: ReparseRequest) -> List[Dict[str, Any]]: + prepared_items = [] + for doc_id in request.doc_ids: + doc = self._get_doc(doc_id) + if doc is None or not self._has_kb_document(request.kb_id, doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') + self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'reparse') + prepared_items.append({ + 'doc_id': doc_id, + 'file_path': doc.get('path'), + 'metadata': from_json(doc.get('meta')), + }) + return prepared_items + + def _prepare_delete_items(self, request: DeleteRequest) -> List[Dict[str, Any]]: + prepared_items = [] + for doc_id in request.doc_ids: + doc = self._get_doc(doc_id) + snapshot = ( + self._get_parse_snapshot(doc_id, request.kb_id, request.algo_id) + or self._get_latest_parse_snapshot(doc_id, request.kb_id) + ) + if doc is None or not self._has_kb_document(request.kb_id, doc_id): + if snapshot is not None and snapshot.get('status') in ( + DocStatus.DELETING.value, + DocStatus.DELETED.value, + ): + prepared_items.append({ + 'doc_id': doc_id, + 
'action': 'noop', + 'status': snapshot['status'], + 'task_id': snapshot.get('current_task_id'), + }) + continue + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}') + if snapshot is not None and snapshot.get('status') in (DocStatus.DELETING.value, DocStatus.DELETED.value): + prepared_items.append({ + 'doc_id': doc_id, + 'action': 'noop', + 'status': snapshot['status'], + 'task_id': snapshot.get('current_task_id'), + }) + continue + self._assert_action_allowed(doc_id, request.kb_id, request.algo_id, 'delete') + prepared_items.append({'doc_id': doc_id, 'action': 'execute'}) + return prepared_items + + def _prepare_metadata_patch_items(self, request: MetadataPatchRequest) -> List[Dict[str, Any]]: + prepared_items = [] + for item in request.items: + doc = self._get_doc(item.doc_id) + if doc is None or not self._has_kb_document(request.kb_id, item.doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') + self._assert_action_allowed(item.doc_id, request.kb_id, request.algo_id, 'metadata') + merged = from_json(doc.get('meta')) + merged.update(item.patch) + prepared_items.append({'doc_id': item.doc_id, 'metadata': merged, 'file_path': doc.get('path')}) + return prepared_items + + def _prepare_transfer_items(self, request: TransferRequest) -> List[Dict[str, Any]]: + prepared_items = [] + seen_pairs = set() + seen_targets = set() + for item in request.items: + if item.mode not in ('move', 'copy'): + raise DocServiceError( + 'E_INVALID_PARAM', f'invalid transfer mode: {item.mode}', {'mode': item.mode} + ) + item_key = (item.doc_id, item.source_kb_id, item.target_kb_id) + if item_key in seen_pairs: + raise DocServiceError( + 'E_INVALID_PARAM', + 'duplicate transfer item detected', + {'doc_id': item.doc_id, 'source_kb_id': item.source_kb_id, 'target_kb_id': item.target_kb_id}, + ) + seen_pairs.add(item_key) + target_key = (item.target_doc_id, item.target_kb_id, item.target_algo_id) + if target_key in seen_targets: + raise DocServiceError( + 'E_INVALID_PARAM', + 'duplicate transfer target detected', + { + 'doc_id': item.target_doc_id, + 'target_kb_id': item.target_kb_id, + 'target_algo_id': item.target_algo_id, + }, + ) + seen_targets.add(target_key) + if item.source_algo_id != item.target_algo_id: + raise DocServiceError( + 'E_INVALID_PARAM', + 'transfer across different algorithms is not supported', + { + 'doc_id': item.doc_id, + 'source_algo_id': item.source_algo_id, + 'target_algo_id': item.target_algo_id, + }, + ) + doc = self._get_doc(item.doc_id) + if doc is None or not self._has_kb_document(item.source_kb_id, item.doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {item.doc_id}') + self._validate_kb_algorithm(item.source_kb_id, item.source_algo_id) + self._validate_kb_algorithm(item.target_kb_id, item.target_algo_id) + self._assert_action_allowed(item.doc_id, item.source_kb_id, item.source_algo_id, 'transfer') + if self._has_kb_document(item.target_kb_id, item.target_doc_id): + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc already exists in target kb: {item.target_doc_id}', + {'doc_id': item.target_doc_id, 'target_kb_id': item.target_kb_id}, + ) + source_snapshot = self._get_parse_snapshot(item.doc_id, item.source_kb_id, item.source_algo_id) + if source_snapshot is None or source_snapshot.get('status') != DocStatus.SUCCESS.value: + raise DocServiceError( + 'E_STATE_CONFLICT', + f'doc transfer requires source parse status SUCCESS: {item.doc_id}', + { + 'doc_id': item.doc_id, + 'source_kb_id': item.source_kb_id, + 
'source_algo_id': item.source_algo_id, + 'status': source_snapshot.get('status') if source_snapshot else None, + }, + ) + source_metadata = from_json(doc.get('meta')) + target_metadata = merge_transfer_metadata(source_metadata, item.target_metadata) + target_path = resolve_transfer_target_path(doc.get('path'), item.target_filename, item.target_file_path) + target_filename = os.path.basename(target_path) + prepared_items.append({ + 'doc_id': item.doc_id, + 'target_doc_id': item.target_doc_id, + 'kb_id': item.kb_id, + 'algo_id': item.algo_id, + 'source_kb_id': item.source_kb_id, + 'source_algo_id': item.source_algo_id, + 'target_kb_id': item.target_kb_id, + 'target_algo_id': item.target_algo_id, + 'mode': item.mode, + 'file_path': doc.get('path'), + 'target_file_path': target_path, + 'target_filename': target_filename, + 'metadata': target_metadata, + 'source_type': SourceType(doc.get('source_type')), + 'transfer_params': { + 'mode': 'mv' if item.mode == 'move' else 'cp', + 'target_algo_id': item.target_algo_id, + 'target_doc_id': item.target_doc_id, + 'target_kb_id': item.target_kb_id, + }, + }) + return prepared_items + + def upload(self, request: UploadRequest) -> List[Dict[str, Any]]: + self._validate_kb_algorithm(request.kb_id, request.algo_id) + prepared_items = self._prepare_upload_items(request) + source_type = request.source_type or SourceType.API + items: List[Dict[str, Any]] = [] + for item in prepared_items: + doc_id = item['doc_id'] + file_path = item['file_path'] + metadata = item['metadata'] + doc = self._upsert_doc( + doc_id=doc_id, + filename=item['filename'], + path=file_path, + metadata=metadata, + source_type=source_type, + upload_status=DocStatus.SUCCESS, + ) + self._ensure_kb_document(request.kb_id, doc_id) + try: + task_id, snapshot = self._enqueue_task( + doc_id, request.kb_id, request.algo_id, TaskType.DOC_ADD, + idempotency_key=request.idempotency_key, + file_path=file_path, + metadata=metadata, + ) + error_code = None + error_msg = None + accepted = True + except Exception as exc: + snapshot = self._get_parse_snapshot(doc_id, request.kb_id, request.algo_id) or {} + doc = self._get_doc(doc_id) or doc + task_id = snapshot.get('current_task_id') + error_code = snapshot.get('last_error_code') or type(exc).__name__ + error_msg = snapshot.get('last_error_msg') or str(exc) + accepted = False + items.append({ + 'doc_id': doc_id, + 'kb_id': request.kb_id, + 'algo_id': request.algo_id, + 'upload_status': doc['upload_status'], + 'parse_status': snapshot.get('status', DocStatus.FAILED.value), + 'task_id': task_id, + 'accepted': accepted, + 'error_code': error_code, + 'error_msg': error_msg, + }) + return items + + def add_files(self, request: AddRequest) -> List[Dict[str, Any]]: + return self.upload(UploadRequest( + items=request.items, + kb_id=request.kb_id, + algo_id=request.algo_id, + source_type=request.source_type or SourceType.EXTERNAL, + idempotency_key=request.idempotency_key, + )) + + def reparse(self, request: ReparseRequest) -> List[str]: + self._validate_kb_algorithm(request.kb_id, request.algo_id) + self._validate_unique_doc_ids(request.doc_ids, field_name='doc_id') + prepared_items = self._prepare_reparse_items(request) + task_ids = [] + for item in prepared_items: + task_id, _ = self._enqueue_task( + item['doc_id'], request.kb_id, request.algo_id, TaskType.DOC_REPARSE, + idempotency_key=request.idempotency_key, + file_path=item['file_path'], + metadata=item['metadata'], + reparse_group='all', + ) + task_ids.append(task_id) + return task_ids + + def delete(self, 
request: DeleteRequest) -> List[Dict[str, Any]]: + self._validate_kb_algorithm(request.kb_id, request.algo_id) + self._validate_unique_doc_ids(request.doc_ids, field_name='doc_id') + prepared_items = self._prepare_delete_items(request) + items: List[Dict[str, Any]] = [] + for item in prepared_items: + doc_id = item['doc_id'] + if item.get('action') == 'noop': + items.append({ + 'doc_id': doc_id, + 'accepted': True, + 'task_id': item.get('task_id'), + 'status': item['status'], + 'error_code': None, + }) + continue + snapshot = self._get_parse_snapshot(doc_id, request.kb_id, request.algo_id) + if ( + snapshot is not None + and snapshot.get('status') == DocStatus.WAITING.value + and snapshot.get('task_type') == TaskType.DOC_ADD.value + and snapshot.get('current_task_id') + ): + cancel_resp = self.cancel_task(snapshot['current_task_id']) + if cancel_resp.code != 200: + raise DocServiceError( + 'E_STATE_CONFLICT', + cancel_resp.msg, + ( + cancel_resp.data + if isinstance(cancel_resp.data, dict) + else {'task_id': snapshot['current_task_id']} + ), + ) + items.append({ + 'doc_id': doc_id, + 'accepted': True, + 'task_id': snapshot['current_task_id'], + 'status': DocStatus.CANCELED.value, + 'error_code': None, + }) + continue + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == doc_id).first() + if self._doc_relation_count(doc_id) <= 1: + row.upload_status = DocStatus.DELETING.value + row.updated_at = datetime.now() + session.add(row) + + task_id, snapshot = self._enqueue_task( + doc_id, request.kb_id, request.algo_id, TaskType.DOC_DELETE, + idempotency_key=request.idempotency_key, + ) + items.append({ + 'doc_id': doc_id, + 'accepted': True, + 'task_id': task_id, + 'status': snapshot['status'], + 'error_code': None, + }) + return items + + def transfer(self, request: TransferRequest) -> List[Dict[str, Any]]: + prepared_items = self._prepare_transfer_items(request) + items: List[Dict[str, Any]] = [] + for item in prepared_items: + task_id = None + try: + self._upsert_doc( + doc_id=item['target_doc_id'], + filename=item['target_filename'], + path=item['target_file_path'], + metadata=item['metadata'], + source_type=item['source_type'], + upload_status=DocStatus.SUCCESS, + allowed_path_doc_ids={item['doc_id']}, + ) + self._ensure_kb_document(item['target_kb_id'], item['target_doc_id']) + task_id, snapshot = self._enqueue_task( + item['target_doc_id'], item['target_kb_id'], item['target_algo_id'], TaskType.DOC_TRANSFER, + idempotency_key=request.idempotency_key, + file_path=item['file_path'], + metadata=item['metadata'], + parser_kb_id=item['source_kb_id'], + transfer_params=item['transfer_params'], + extra_message={ + 'source_doc_id': item['doc_id'], + 'source_kb_id': item['source_kb_id'], + 'source_algo_id': item['source_algo_id'], + 'target_kb_id': item['target_kb_id'], + 'target_algo_id': item['target_algo_id'], + 'target_doc_id': item['target_doc_id'], + 'mode': item['mode'], + }, + parser_doc_id=item['doc_id'], + ) + error_code = None + error_msg = None + accepted = True + except Exception as exc: + snapshot = self._get_parse_snapshot( + item['target_doc_id'], item['target_kb_id'], item['target_algo_id'] + ) or {} + task_id = task_id or snapshot.get('current_task_id') + error_code = snapshot.get('last_error_code') + if not error_code: + error_code = exc.biz_code if isinstance(exc, DocServiceError) else type(exc).__name__ + error_msg = snapshot.get('last_error_msg') or ( + exc.msg 
if isinstance(exc, DocServiceError) else str(exc) + ) + accepted = False + items.append({ + 'doc_id': item['doc_id'], + 'target_doc_id': item['target_doc_id'], + 'task_id': task_id, + 'kb_id': item['kb_id'], + 'algo_id': item['algo_id'], + 'source_kb_id': item['source_kb_id'], + 'target_kb_id': item['target_kb_id'], + 'source_algo_id': item['source_algo_id'], + 'target_algo_id': item['target_algo_id'], + 'mode': item['mode'], + 'target_file_path': item['target_file_path'], + 'status': snapshot.get('status', DocStatus.FAILED.value), + 'accepted': accepted, + 'error_code': error_code, + 'error_msg': error_msg, + }) + return items + + def patch_metadata(self, request: MetadataPatchRequest): + self._validate_kb_algorithm(request.kb_id, request.algo_id) + updated = [] + failed = [] + self._validate_unique_doc_ids([item.doc_id for item in request.items], field_name='doc_id') + prepared_items = self._prepare_metadata_patch_items(request) + for item in prepared_items: + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + row = session.query(Doc).filter(Doc.doc_id == item['doc_id']).first() + row.meta = to_json(item['metadata']) + row.updated_at = datetime.now() + session.add(row) + + task_id, _ = self._enqueue_task( + item['doc_id'], request.kb_id, request.algo_id, TaskType.DOC_UPDATE_META, + idempotency_key=request.idempotency_key, + file_path=item['file_path'], + metadata=item['metadata'], + ) + updated.append({'doc_id': item['doc_id'], 'task_id': task_id}) + return { + 'updated_count': len(updated), + 'doc_ids': [u['doc_id'] for u in updated], + 'failed_items': failed, + 'items': updated, + } + + def _sync_doc_upload_status(self, doc_id: str): + target = DocStatus.SUCCESS if self._doc_relation_count(doc_id) > 0 else DocStatus.DELETED + self._set_doc_upload_status(doc_id, target) + + @staticmethod + def _build_snapshot_update(snapshot: Optional[Dict[str, Any]], **overrides): + snapshot = snapshot or {} + data = { + 'task_type': TaskType(snapshot['task_type']) if snapshot.get('task_type') else None, + 'current_task_id': snapshot.get('current_task_id'), + 'idempotency_key': snapshot.get('idempotency_key'), + 'priority': snapshot.get('priority', 0), + 'task_score': snapshot.get('task_score'), + 'retry_count': snapshot.get('retry_count', 0), + 'max_retry': snapshot.get('max_retry', 3), + 'lease_owner': snapshot.get('lease_owner'), + 'lease_until': snapshot.get('lease_until'), + 'error_code': snapshot.get('last_error_code'), + 'error_msg': snapshot.get('last_error_msg'), + 'failed_stage': snapshot.get('failed_stage'), + 'queued_at': snapshot.get('queued_at'), + 'started_at': snapshot.get('started_at'), + 'finished_at': snapshot.get('finished_at'), + } + data.update(overrides) + return data + + def _resolve_callback_task(self, callback: TaskCallbackRequest): + task = self._get_task_record(callback.task_id) + if task is not None: + return task + payload = callback.payload or {} + required_fields = {'task_type', 'doc_id', 'kb_id', 'algo_id'} + if required_fields.issubset(payload.keys()): + return { + 'task_id': callback.task_id, + 'task_type': payload['task_type'], + 'doc_id': payload['doc_id'], + 'kb_id': payload['kb_id'], + 'algo_id': payload['algo_id'], + } + return None + + def on_task_callback(self, callback: TaskCallbackRequest): # noqa: C901 + if not self._record_callback(callback.callback_id, callback.task_id): + return {'ack': True, 'deduped': True, 'ignored_reason': None} + try: + task_data = 
self._resolve_callback_task(callback) + if task_data is None: + return {'ack': True, 'ignored_reason': 'task_not_found'} + doc_id = task_data['doc_id'] + kb_id = task_data['kb_id'] + algo_id = task_data['algo_id'] + task_type = TaskType(task_data['task_type']) + snapshot = self._get_parse_snapshot(doc_id, kb_id, algo_id) + if snapshot and snapshot.get('current_task_id') and snapshot['current_task_id'] != callback.task_id: + return {'ack': True, 'deduped': False, 'ignored_reason': 'stale_task_callback'} + if task_data.get('status') == DocStatus.CANCELED.value and callback.status != DocStatus.CANCELED: + return {'ack': True, 'deduped': False, 'ignored_reason': 'canceled_task_callback'} + + if callback.event_type == CallbackEventType.START: + self._update_task_record( + callback.task_id, + status=DocStatus.WORKING.value, + started_at=datetime.now(), + finished_at=None, + error_code=None, + error_msg=None, + ) + start_status = DocStatus.DELETING if task_type == TaskType.DOC_DELETE else DocStatus.WORKING + self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=start_status, + **self._build_snapshot_update( + snapshot, + task_type=task_type, + current_task_id=callback.task_id, + started_at=datetime.now(), + finished_at=None, + error_code=None, + error_msg=None, + failed_stage=None, + ), + ) + if task_type == TaskType.DOC_ADD: + self._apply_doc_upload_status(doc_id, task_type, DocStatus.WORKING) + elif task_type == TaskType.DOC_DELETE: + self._apply_doc_upload_status(doc_id, task_type, DocStatus.DELETING) + return {'ack': True, 'deduped': False, 'ignored_reason': None} + + final_status = callback.status + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.SUCCESS: + final_status = DocStatus.DELETED + failed_stage = None + if final_status == DocStatus.FAILED: + failed_stage = 'DELETE' if task_type == TaskType.DOC_DELETE else 'PARSE' + task_message = task_data.get('message') or {} + cleanup_policy = task_message.get('cleanup_policy') + + if ( + task_type == TaskType.DOC_TRANSFER + and final_status == DocStatus.SUCCESS + and task_message.get('mode') == 'move' + ): + source_kb_id = task_message.get('source_kb_id') + source_doc_id = task_message.get('source_doc_id') or doc_id + if source_kb_id and source_kb_id != kb_id: + self._remove_kb_document(source_kb_id, source_doc_id) + self._delete_parse_snapshots(source_doc_id, source_kb_id) + self._sync_doc_upload_status(source_doc_id) + + self._update_task_record( + callback.task_id, + status=final_status.value, + error_code=callback.error_code, + error_msg=callback.error_msg, + finished_at=datetime.now(), + ) + + self._upsert_parse_snapshot( + doc_id=doc_id, + kb_id=kb_id, + algo_id=algo_id, + status=final_status, + **self._build_snapshot_update( + snapshot, + task_type=task_type, + current_task_id=callback.task_id, + error_code=callback.error_code, + error_msg=callback.error_msg, + failed_stage=failed_stage, + finished_at=datetime.now(), + ), + ) + + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED: + self._remove_kb_document(kb_id, doc_id) + self._apply_doc_upload_status(doc_id, task_type, final_status) + if task_type == TaskType.DOC_DELETE and final_status == DocStatus.DELETED and cleanup_policy == 'purge': + self._purge_deleted_kb_doc_data(kb_id, doc_id) + self._finalize_kb_deletion_if_empty(kb_id) + + return {'ack': True, 'deduped': False, 'ignored_reason': None} + except Exception: + self._forget_callback_record(callback.callback_id, callback.task_id) + raise + + def list_docs( 
+ self, + status: Optional[List[str]] = None, + kb_id: Optional[str] = None, + algo_id: Optional[str] = None, + keyword: Optional[str] = None, + include_deleted_or_canceled: bool = True, + page: int = 1, + page_size: int = 20, + ): + page = max(page, 1) + page_size = max(1, page_size) + with self._db_manager.get_session() as session: + Doc = self._db_manager.get_table_orm_class(DOCUMENTS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_DOCUMENTS_TABLE_INFO['name']) + query = session.query(Doc, Rel).join(Rel, Doc.doc_id == Rel.doc_id) + + if kb_id: + query = query.filter(Rel.kb_id == kb_id) + if keyword: + like_expr = f'%{keyword}%' + query = query.filter((Doc.filename.like(like_expr)) | (Doc.path.like(like_expr))) + if not include_deleted_or_canceled: + query = query.filter(~Doc.upload_status.in_([DocStatus.DELETED.value, DocStatus.CANCELED.value])) + + rows = query.order_by(Rel.updated_at.desc(), Doc.updated_at.desc()).all() + items = [] + for doc_row, rel_row in rows: + doc = _orm_to_dict(doc_row) + relation = _orm_to_dict(rel_row) + snapshot = ( + self._get_parse_snapshot(doc['doc_id'], relation['kb_id'], algo_id) + if algo_id else + self._get_latest_parse_snapshot(doc['doc_id'], relation['kb_id']) + ) + if status and (snapshot is None or snapshot.get('status') not in status): + continue + doc['metadata'] = from_json(doc.get('meta')) + items.append({'doc': doc, 'relation': relation, 'snapshot': snapshot}) + total = len(items) + page_items = items[(page - 1) * page_size:page * page_size] + return {'items': page_items, 'total': total, 'page': page, 'page_size': page_size} + + def get_doc_detail(self, doc_id: str): + doc = self._get_doc(doc_id) + if not doc: + raise DocServiceError('E_NOT_FOUND', f'doc not found: {doc_id}', {'doc_id': doc_id}) + + doc['metadata'] = from_json(doc.get('meta')) + snapshots = self.list_docs(page=1, page_size=2000)['items'] + matched_items = [] + for item in snapshots: + if item['doc']['doc_id'] == doc_id: + matched_items.append(item) + relation = matched_items[0].get('relation') if matched_items else None + snapshot = matched_items[0].get('snapshot') if matched_items else None + latest_task = None + if snapshot and snapshot.get('current_task_id'): + latest_task = self._get_task_record(snapshot['current_task_id']) + return { + 'doc': doc, + 'relation': relation, + 'snapshot': snapshot, + 'latest_task': latest_task, + 'relations': [item.get('relation') for item in matched_items], + 'snapshots': [item.get('snapshot') for item in matched_items if item.get('snapshot') is not None], + } + + def list_tasks(self, status: Optional[List[str]], page: int, page_size: int): + parser_list_tasks = getattr(self._parser_client, 'list_tasks', None) + if callable(parser_list_tasks): + try: + return parser_list_tasks(status=status, page=page, page_size=page_size) + except Exception as exc: + LOG.warning(f'[DocService] Fallback to local task list: {exc}') + page = max(page, 1) + page_size = max(page_size, 1) + with self._db_manager.get_session() as session: + Task = self._db_manager.get_table_orm_class(DOC_SERVICE_TASKS_TABLE_INFO['name']) + query = session.query(Task) + if status: + query = query.filter(Task.status.in_(status)) + total = query.count() + rows = ( + query.order_by(Task.created_at.desc()) + .offset((page - 1) * page_size) + .limit(page_size) + .all() + ) + items = [] + for row in rows: + task = _orm_to_dict(row) + task['message'] = from_json(task.get('message')) + items.append(task) + return BaseResponse(code=200, msg='success', data={ + 'items': 
items, + 'total': total, + 'page': page, + 'page_size': page_size, + }) + + def get_task(self, task_id: str): + parser_get_task = getattr(self._parser_client, 'get_task', None) + if callable(parser_get_task): + try: + return parser_get_task(task_id) + except Exception as exc: + LOG.warning(f'[DocService] Fallback to local task query: {exc}') + task = self._get_task_record(task_id) + if task is None: + return BaseResponse(code=404, msg='task not found', data=None) + return BaseResponse(code=200, msg='success', data=task) + + def get_tasks_batch(self, task_ids: List[str]): + items = [] + for task_id in task_ids: + resp = self.get_task(task_id) + if resp.code == 200 and resp.data is not None: + items.append(resp.data) + return {'items': items} + + def cancel_task(self, task_id: str): + task = self._get_task_record(task_id) + if task is None: + return BaseResponse(code=404, msg='task not found', data={'task_id': task_id, 'cancel_status': False}) + if task.get('status') != DocStatus.WAITING.value: + return BaseResponse( + code=409, + msg='task cannot be canceled', + data={'task_id': task_id, 'cancel_status': False, 'status': task.get('status')}, + ) + resp = self._parser_client.cancel_task(task_id) + if resp.code != 200: + return resp + resp_data = resp.data or {} + if not resp_data.get('cancel_status'): + return BaseResponse( + code=409, + msg=resp_data.get('message', 'task cannot be canceled'), + data={'task_id': task_id, 'cancel_status': False, 'status': task.get('status')}, + ) + self.on_task_callback(TaskCallbackRequest( + task_id=task_id, + event_type=CallbackEventType.FINISH, + status=DocStatus.CANCELED, + )) + return BaseResponse( + code=200, + msg='success', + data={'task_id': task_id, 'cancel_status': True, 'status': DocStatus.CANCELED.value}, + ) + + def list_algorithms(self): + resp = self._parser_client.list_algorithms() + if resp.code != 200: + raise fastapi.HTTPException(status_code=502, detail=resp.msg) + return resp.data + + def get_algo_groups(self, algo_id: str): + resp = self._parser_client.get_algorithm_groups(algo_id) + if resp.code == 404: + raise fastapi.HTTPException(status_code=404, detail='algo not found') + if resp.code != 200: + raise fastapi.HTTPException(status_code=502, detail=resp.msg) + return resp.data + + def list_algorithms_compat(self): + items = self.list_algorithms() + return {'items': items} + + def get_algorithm_info(self, algo_id: str): + algorithms = self.list_algorithms() + for item in algorithms: + if item['algo_id'] == algo_id: + data = dict(item) + data['groups'] = self.get_algo_groups(algo_id) + return data + raise DocServiceError('E_NOT_FOUND', f'algo not found: {algo_id}') + + def list_chunks( + self, + kb_id: str, + doc_id: str, + group: str, + algo_id: str = '__default__', + page: int = 1, + page_size: int = 20, + offset: Optional[int] = None, + ): + if not kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required', {'kb_id': kb_id}) + if not doc_id: + raise DocServiceError('E_INVALID_PARAM', 'doc_id is required', {'doc_id': doc_id}) + if not group: + raise DocServiceError('E_INVALID_PARAM', 'group is required', {'group': group}) + self._validate_kb_algorithm(kb_id, algo_id) + doc = self._get_doc(doc_id) + if doc is None or not self._has_kb_document(kb_id, doc_id): + raise DocServiceError('E_NOT_FOUND', f'doc not found in kb: {doc_id}', {'kb_id': kb_id, 'doc_id': doc_id}) + groups = self.get_algo_groups(algo_id) + if not any(item.get('name') == group for item in groups): + raise DocServiceError( + 'E_INVALID_PARAM', + f'invalid 
group: {group}', + {'algo_id': algo_id, 'group': group}, + ) + page = max(page, 1) + page_size = max(page_size, 1) + offset = (page - 1) * page_size if offset is None else max(offset, 0) + resp = self._parser_client.list_doc_chunks( + algo_id=algo_id, + kb_id=kb_id, + doc_id=doc_id, + group=group, + offset=offset, + page_size=page_size, + ) + if resp.code == 404: + raise DocServiceError('E_NOT_FOUND', resp.msg, {'kb_id': kb_id, 'doc_id': doc_id, 'group': group}) + if resp.code == 400: + raise DocServiceError('E_INVALID_PARAM', resp.msg, {'kb_id': kb_id, 'doc_id': doc_id, 'group': group}) + if resp.code != 200: + raise fastapi.HTTPException(status_code=502, detail=resp.msg) + data = dict(resp.data or {}) + data['page'] = page + data['page_size'] = page_size + data['offset'] = offset + data.setdefault('items', []) + data.setdefault('total', 0) + return data + + def health(self): + parser_ok = False + try: + parser_ok = self._parser_client.health().code == 200 + except Exception as exc: + LOG.warning(f'[DocService] Parser health check failed: {exc}') + return { + 'status': 'ok', + 'version': 'v1', + 'deps': { + 'sql': True, + 'parser': parser_ok, + }, + } + + def list_kbs( + self, + page: int = 1, + page_size: int = 20, + keyword: Optional[str] = None, + status: Optional[List[str]] = None, + owner_id: Optional[str] = None, + ): + page = max(page, 1) + page_size = max(page_size, 1) + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + query = session.query(Kb, Rel).outerjoin(Rel, Rel.kb_id == Kb.kb_id) + query = query.filter(Kb.kb_id != '__default__') + if keyword: + like_expr = f'%{keyword}%' + query = query.filter( + sqlalchemy.or_( + Kb.kb_id.like(like_expr), + Kb.display_name.like(like_expr), + Kb.description.like(like_expr), + ) + ) + if status: + query = query.filter(Kb.status.in_(status)) + if owner_id: + query = query.filter(Kb.owner_id == owner_id) + total = query.count() + rows = ( + query.order_by(Kb.updated_at.desc(), Kb.created_at.desc()) + .offset((page - 1) * page_size) + .limit(page_size) + .all() + ) + items = [self._build_kb_data(kb_row, algo_row) for kb_row, algo_row in rows] + return {'items': items, 'total': total, 'page': page, 'page_size': page_size} + + def get_kb(self, kb_id: str): + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + row = ( + session.query(Kb, Rel) + .outerjoin(Rel, Rel.kb_id == Kb.kb_id) + .filter(Kb.kb_id == kb_id) + .first() + ) + if row is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + kb_row, algo_row = row + return self._build_kb_data(kb_row, algo_row) + + def batch_get_kbs(self, kb_ids: List[str]): + if not kb_ids: + raise DocServiceError('E_INVALID_PARAM', 'kb_ids is required', {'kb_ids': kb_ids}) + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + Rel = self._db_manager.get_table_orm_class(KB_ALGORITHM_TABLE_INFO['name']) + rows = ( + session.query(Kb, Rel) + .outerjoin(Rel, Rel.kb_id == Kb.kb_id) + .filter(Kb.kb_id.in_(kb_ids)) + .all() + ) + row_map = {kb_row.kb_id: self._build_kb_data(kb_row, algo_row) for kb_row, algo_row in rows} + items = [] + missing_kb_ids = [] + for kb_id in kb_ids: + if kb_id in row_map: + items.append(row_map[kb_id]) + 
else: + missing_kb_ids.append(kb_id) + return {'items': items, 'missing_kb_ids': missing_kb_ids} + + def create_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + algo_id: str = '__default__'): + if not kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + if self._get_kb(kb_id) is not None: + raise DocServiceError('E_STATE_CONFLICT', f'kb already exists: {kb_id}', {'kb_id': kb_id}) + self._ensure_algorithm_exists(algo_id) + self._ensure_kb( + kb_id, display_name=display_name, description=description, owner_id=owner_id, meta=meta, + ) + self._ensure_kb_algorithm(kb_id, algo_id) + return self.get_kb(kb_id) + + def update_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + algo_id: Optional[str] = None, explicit_fields: Optional[Set[str]] = None): + if not kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + kb = self._get_kb(kb_id) + if kb is None: + raise DocServiceError('E_NOT_FOUND', f'kb not found: {kb_id}', {'kb_id': kb_id}) + explicit_fields = explicit_fields or set() + if 'algo_id' in explicit_fields: + if algo_id is None: + raise DocServiceError('E_INVALID_PARAM', 'algo_id cannot be null', {'kb_id': kb_id}) + self._ensure_algorithm_exists(algo_id) + if algo_id is not None: + binding = self._get_kb_algorithm(kb_id) + if binding is None: + self._ensure_kb_algorithm(kb_id, algo_id) + elif binding['algo_id'] != algo_id: + raise DocServiceError( + 'E_STATE_CONFLICT', f'kb {kb_id} is already bound to algorithm {binding["algo_id"]}', + {'kb_id': kb_id, 'bound_algo_id': binding['algo_id'], 'requested_algo_id': algo_id} + ) + self._ensure_kb( + kb_id, display_name=display_name, description=description, owner_id=owner_id, meta=meta, + update_fields=explicit_fields & {'display_name', 'description', 'owner_id', 'meta'}, + ) + return self.get_kb(kb_id) + + def delete_kb(self, kb_id: str): + if not kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + prepared = self._prepare_kb_delete_items(kb_id) + task_ids = [] + for item in prepared['items']: + if item['action'] == 'purge_local': + self._purge_deleted_kb_doc_data(kb_id, item['doc_id'], remove_relation=True) + continue + if item['action'] == 'reuse_delete_task': + self._mark_task_cleanup_policy(item['task_id'], 'purge') + task_ids.append(item['task_id']) + continue + if item['action'] == 'enqueue_delete': + task_id, _ = self._enqueue_task( + item['doc_id'], kb_id, item['algo_id'], TaskType.DOC_DELETE, cleanup_policy='purge' + ) + task_ids.append(task_id) + continue + raise RuntimeError(f'unsupported kb delete action: {item["action"]}') + new_status = KBStatus.DELETING.value if task_ids else KBStatus.DELETED.value + if task_ids: + with self._db_manager.get_session() as session: + Kb = self._db_manager.get_table_orm_class(KBS_TABLE_INFO['name']) + kb_row = session.query(Kb).filter(Kb.kb_id == kb_id).first() + if kb_row is not None: + kb_row.status = new_status + kb_row.updated_at = datetime.now() + session.add(kb_row) + else: + self._finalize_kb_deletion_if_empty(kb_id) + return {'kb_id': kb_id, 'status': new_status, 'task_ids': task_ids} + + def delete_kbs(self, kb_ids: List[str]): + if not kb_ids: + raise DocServiceError('E_INVALID_PARAM', 'kb_ids is required', {'kb_ids': kb_ids}) + items = [] + for kb_id in kb_ids: + items.append(self.delete_kb(kb_id)) + 
return {'items': items} diff --git a/lazyllm/tools/rag/doc_service/doc_server.openapi.json b/lazyllm/tools/rag/doc_service/doc_server.openapi.json new file mode 100644 index 000000000..0577f9e01 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/doc_server.openapi.json @@ -0,0 +1,2078 @@ +{ + "components": { + "schemas": { + "AddFileItem": { + "properties": { + "doc_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Doc Id" + }, + "file_path": { + "title": "File Path", + "type": "string" + }, + "metadata": { + "additionalProperties": true, + "title": "Metadata", + "type": "object" + } + }, + "required": [ + "file_path" + ], + "title": "AddFileItem", + "type": "object" + }, + "AlgorithmInfoRequest": { + "properties": { + "algo_id": { + "title": "Algo Id", + "type": "string" + } + }, + "required": [ + "algo_id" + ], + "title": "AlgorithmInfoRequest", + "type": "object" + }, + "Body_upload_v1_docs_upload_post": { + "properties": { + "algo_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Algo Id" + }, + "doc_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Doc Id" + }, + "files": { + "items": { + "contentMediaType": "application/octet-stream", + "type": "string" + }, + "title": "Files", + "type": "array" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Kb Id" + }, + "source_type": { + "anyOf": [ + { + "$ref": "#/components/schemas/SourceType" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "files" + ], + "title": "Body_upload_v1_docs_upload_post", + "type": "object" + }, + "CallbackEventType": { + "enum": [ + "START", + "FINISH" + ], + "title": "CallbackEventType", + "type": "string" + }, + "DeleteRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "doc_ids": { + "items": { + "type": "string" + }, + "title": "Doc Ids", + "type": "array" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_id": { + "default": "__default__", + "title": "Kb Id", + "type": "string" + } + }, + "required": [ + "doc_ids" + ], + "title": "DeleteRequest", + "type": "object" + }, + "DocItemsRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "items": { + "items": { + "$ref": "#/components/schemas/AddFileItem" + }, + "title": "Items", + "type": "array" + }, + "kb_id": { + "default": "__default__", + "title": "Kb Id", + "type": "string" + }, + "source_type": { + "anyOf": [ + { + "$ref": "#/components/schemas/SourceType" + }, + { + "type": "null" + } + ] + } + }, + "required": [ + "items" + ], + "title": "DocItemsRequest", + "type": "object" + }, + "DocStatus": { + "enum": [ + "WAITING", + "WORKING", + "SUCCESS", + "FAILED", + "CANCELED", + "DELETING", + "DELETED" + ], + "title": "DocStatus", + "type": "string" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { + "$ref": "#/components/schemas/ValidationError" + }, + "title": "Detail", + "type": "array" + } + }, + "title": "HTTPValidationError", + "type": "object" + }, + "KbBatchQueryRequest": { 
+ "properties": { + "kb_ids": { + "items": { + "type": "string" + }, + "title": "Kb Ids", + "type": "array" + } + }, + "required": [ + "kb_ids" + ], + "title": "KbBatchQueryRequest", + "type": "object" + }, + "KbCreateRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "display_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Display Name" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_id": { + "title": "Kb Id", + "type": "string" + }, + "meta": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Meta" + }, + "owner_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Owner Id" + } + }, + "required": [ + "kb_id" + ], + "title": "KbCreateRequest", + "type": "object" + }, + "KbDeleteBatchRequest": { + "properties": { + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_ids": { + "items": { + "type": "string" + }, + "title": "Kb Ids", + "type": "array" + } + }, + "required": [ + "kb_ids" + ], + "title": "KbDeleteBatchRequest", + "type": "object" + }, + "KbUpdateRequest": { + "properties": { + "algo_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Algo Id" + }, + "description": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Description" + }, + "display_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Display Name" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "meta": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Meta" + }, + "owner_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Owner Id" + } + }, + "title": "KbUpdateRequest", + "type": "object" + }, + "MetadataPatchItem": { + "properties": { + "doc_id": { + "title": "Doc Id", + "type": "string" + }, + "patch": { + "additionalProperties": true, + "title": "Patch", + "type": "object" + } + }, + "required": [ + "doc_id" + ], + "title": "MetadataPatchItem", + "type": "object" + }, + "MetadataPatchRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "items": { + "items": { + "$ref": "#/components/schemas/MetadataPatchItem" + }, + "title": "Items", + "type": "array" + }, + "kb_id": { + "default": "__default__", + "title": "Kb Id", + "type": "string" + } + }, + "required": [ + "items" + ], + "title": "MetadataPatchRequest", + "type": "object" + }, + "ReparseRequest": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "doc_ids": { + "items": { + "type": "string" + }, + "title": "Doc Ids", + "type": "array" + }, + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "kb_id": { + "default": "__default__", + 
"title": "Kb Id", + "type": "string" + } + }, + "required": [ + "doc_ids" + ], + "title": "ReparseRequest", + "type": "object" + }, + "SourceType": { + "enum": [ + "API", + "SCAN", + "TEMP", + "EXTERNAL" + ], + "title": "SourceType", + "type": "string" + }, + "TaskBatchRequest": { + "properties": { + "task_ids": { + "items": { + "type": "string" + }, + "title": "Task Ids", + "type": "array" + } + }, + "required": [ + "task_ids" + ], + "title": "TaskBatchRequest", + "type": "object" + }, + "TaskCallbackPayload": { + "properties": { + "algo_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Algo Id" + }, + "callback_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Callback Id" + }, + "doc_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Doc Id" + }, + "error_code": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Code" + }, + "error_msg": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Error Msg" + }, + "event_type": { + "anyOf": [ + { + "$ref": "#/components/schemas/CallbackEventType" + }, + { + "type": "null" + } + ] + }, + "kb_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Kb Id" + }, + "payload": { + "additionalProperties": true, + "title": "Payload", + "type": "object" + }, + "status": { + "anyOf": [ + { + "$ref": "#/components/schemas/DocStatus" + }, + { + "type": "null" + } + ] + }, + "task_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Task Id" + }, + "task_status": { + "anyOf": [ + { + "$ref": "#/components/schemas/DocStatus" + }, + { + "type": "null" + } + ] + }, + "task_type": { + "anyOf": [ + { + "$ref": "#/components/schemas/TaskType" + }, + { + "type": "null" + } + ] + } + }, + "title": "TaskCallbackPayload", + "type": "object" + }, + "TaskCancelRequest": { + "properties": { + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "task_id": { + "title": "Task Id", + "type": "string" + } + }, + "required": [ + "task_id" + ], + "title": "TaskCancelRequest", + "type": "object" + }, + "TaskInfoRequest": { + "properties": { + "task_id": { + "title": "Task Id", + "type": "string" + } + }, + "required": [ + "task_id" + ], + "title": "TaskInfoRequest", + "type": "object" + }, + "TaskType": { + "enum": [ + "DOC_ADD", + "DOC_DELETE", + "DOC_UPDATE_META", + "DOC_REPARSE", + "DOC_TRANSFER" + ], + "title": "TaskType", + "type": "string" + }, + "TransferItem": { + "properties": { + "algo_id": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + }, + "doc_id": { + "title": "Doc Id", + "type": "string" + }, + "kb_id": { + "default": "__default__", + "title": "Kb Id", + "type": "string" + }, + "mode": { + "default": "copy", + "title": "Mode", + "type": "string" + }, + "target_algo_id": { + "title": "Target Algo Id", + "type": "string" + }, + "target_doc_id": { + "title": "Target Doc Id", + "type": "string" + }, + "target_file_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Target File Path" + }, + "target_filename": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Target Filename" + }, + "target_kb_id": { + "title": "Target Kb Id", + "type": "string" + }, + "target_metadata": { + "anyOf": [ + { + "additionalProperties": true, + "type": 
"object" + }, + { + "type": "null" + } + ], + "title": "Target Metadata" + } + }, + "required": [ + "doc_id", + "target_doc_id", + "target_kb_id", + "target_algo_id" + ], + "title": "TransferItem", + "type": "object" + }, + "TransferRequest": { + "properties": { + "idempotency_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + }, + "items": { + "items": { + "$ref": "#/components/schemas/TransferItem" + }, + "title": "Items", + "type": "array" + } + }, + "required": [ + "items" + ], + "title": "TransferRequest", + "type": "object" + }, + "ValidationError": { + "properties": { + "ctx": { + "title": "Context", + "type": "object" + }, + "input": { + "title": "Input" + }, + "loc": { + "items": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "integer" + } + ] + }, + "title": "Location", + "type": "array" + }, + "msg": { + "title": "Message", + "type": "string" + }, + "type": { + "title": "Error Type", + "type": "string" + } + }, + "required": [ + "loc", + "msg", + "type" + ], + "title": "ValidationError", + "type": "object" + } + } + }, + "info": { + "description": "OpenAPI schema generated from current DocServer routes.", + "title": "LazyLLM DocService API", + "version": "1.0.0" + }, + "openapi": "3.1.0", + "paths": { + "/v1/algo/list": { + "get": { + "operationId": "list_algo_v1_algo_list_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + } + }, + "summary": "List Algo" + } + }, + "/v1/algo/{algo_id}/groups": { + "get": { + "operationId": "get_algo_groups_v1_algo__algo_id__groups_get", + "parameters": [ + { + "in": "path", + "name": "algo_id", + "required": true, + "schema": { + "title": "Algo Id", + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Algo Groups" + } + }, + "/v1/algorithms": { + "get": { + "operationId": "list_algorithms_v1_algorithms_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + } + }, + "summary": "List Algorithms" + } + }, + "/v1/algorithms/info": { + "post": { + "operationId": "get_algorithm_info_v1_algorithms_info_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/AlgorithmInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Algorithm Info" + } + }, + "/v1/chunks": { + "get": { + "operationId": "list_chunks_v1_chunks_get", + "parameters": [ + { + "in": "query", + "name": "kb_id", + "required": true, + "schema": { + "title": "Kb Id", + "type": "string" + } + }, + { + "in": "query", + "name": "doc_id", + "required": true, + "schema": { + "title": "Doc Id", + "type": "string" + } + }, + { + "in": "query", + "name": "group", + "required": true, + "schema": { + "title": "Group", + "type": "string" + } + }, + { + "in": "query", + "name": "algo_id", + 
"required": false, + "schema": { + "default": "__default__", + "title": "Algo Id", + "type": "string" + } + }, + { + "in": "query", + "name": "page", + "required": false, + "schema": { + "default": 1, + "title": "Page", + "type": "integer" + } + }, + { + "in": "query", + "name": "page_size", + "required": false, + "schema": { + "default": 20, + "title": "Page Size", + "type": "integer" + } + }, + { + "in": "query", + "name": "offset", + "required": false, + "schema": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "title": "Offset" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "List Chunks" + } + }, + "/v1/docs": { + "get": { + "operationId": "list_docs_v1_docs_get", + "parameters": [ + { + "in": "query", + "name": "kb_id", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Kb Id" + } + }, + { + "in": "query", + "name": "algo_id", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Algo Id" + } + }, + { + "in": "query", + "name": "keyword", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Keyword" + } + }, + { + "in": "query", + "name": "include_deleted_or_canceled", + "required": false, + "schema": { + "default": true, + "title": "Include Deleted Or Canceled", + "type": "boolean" + } + }, + { + "in": "query", + "name": "page", + "required": false, + "schema": { + "default": 1, + "title": "Page", + "type": "integer" + } + }, + { + "in": "query", + "name": "page_size", + "required": false, + "schema": { + "default": 20, + "title": "Page Size", + "type": "integer" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Status" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "List Docs" + } + }, + "/v1/docs/add": { + "post": { + "operationId": "add_v1_docs_add_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DocItemsRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Add" + } + }, + "/v1/docs/delete": { + "post": { + "operationId": "delete_v1_docs_delete_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/DeleteRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + 
"application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Delete" + } + }, + "/v1/docs/metadata/patch": { + "post": { + "operationId": "patch_metadata_v1_docs_metadata_patch_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/MetadataPatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Patch Metadata" + } + }, + "/v1/docs/reparse": { + "post": { + "operationId": "reparse_v1_docs_reparse_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ReparseRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Reparse" + } + }, + "/v1/docs/transfer": { + "post": { + "operationId": "transfer_v1_docs_transfer_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TransferRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Transfer" + } + }, + "/v1/docs/upload": { + "post": { + "operationId": "upload_v1_docs_upload_post", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_upload_v1_docs_upload_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Upload" + } + }, + "/v1/docs/{doc_id}": { + "get": { + "operationId": "get_doc_v1_docs__doc_id__get", + "parameters": [ + { + "in": "path", + "name": "doc_id", + "required": true, + "schema": { + "title": "Doc Id", + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Doc" + } + }, + "/v1/health": { + "get": { + "operationId": "health_v1_health_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + } + }, + "summary": "Health" + } + }, + "/v1/internal/callbacks/tasks": { + "post": { + "operationId": "task_callback_http_v1_internal_callbacks_tasks_post", + "requestBody": { + "content": { + 
"application/json": { + "schema": { + "$ref": "#/components/schemas/TaskCallbackPayload" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Task Callback Http" + } + }, + "/v1/internal/parser-url": { + "get": { + "operationId": "get_parser_url_v1_internal_parser_url_get", + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + } + }, + "summary": "Get Parser Url" + } + }, + "/v1/kbs": { + "delete": { + "operationId": "delete_kbs_v1_kbs_delete", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/KbDeleteBatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Delete Kbs" + }, + "get": { + "operationId": "list_kbs_v1_kbs_get", + "parameters": [ + { + "in": "query", + "name": "page", + "required": false, + "schema": { + "default": 1, + "title": "Page", + "type": "integer" + } + }, + { + "in": "query", + "name": "page_size", + "required": false, + "schema": { + "default": 20, + "title": "Page Size", + "type": "integer" + } + }, + { + "in": "query", + "name": "keyword", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Keyword" + } + }, + { + "in": "query", + "name": "owner_id", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Owner Id" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Status" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "List Kbs" + }, + "post": { + "operationId": "create_kb_v1_kbs_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/KbCreateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Create Kb" + } + }, + "/v1/kbs/batch": { + "post": { + "operationId": "batch_get_kbs_v1_kbs_batch_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/KbBatchQueryRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + 
"422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Batch Get Kbs" + } + }, + "/v1/kbs/{kb_id}": { + "delete": { + "operationId": "delete_kb_v1_kbs__kb_id__delete", + "parameters": [ + { + "in": "path", + "name": "kb_id", + "required": true, + "schema": { + "title": "Kb Id", + "type": "string" + } + }, + { + "in": "query", + "name": "idempotency_key", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Idempotency Key" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Delete Kb" + }, + "get": { + "operationId": "get_kb_v1_kbs__kb_id__get", + "parameters": [ + { + "in": "path", + "name": "kb_id", + "required": true, + "schema": { + "title": "Kb Id", + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Kb" + } + }, + "/v1/kbs/{kb_id}/update": { + "post": { + "operationId": "update_kb_v1_kbs__kb_id__update_post", + "parameters": [ + { + "in": "path", + "name": "kb_id", + "required": true, + "schema": { + "title": "Kb Id", + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/KbUpdateRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Update Kb" + } + }, + "/v1/tasks": { + "get": { + "operationId": "list_tasks_v1_tasks_get", + "parameters": [ + { + "in": "query", + "name": "page", + "required": false, + "schema": { + "default": 1, + "title": "Page", + "type": "integer" + } + }, + { + "in": "query", + "name": "page_size", + "required": false, + "schema": { + "default": 20, + "title": "Page Size", + "type": "integer" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Status" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "List Tasks" + } + }, + "/v1/tasks/batch": { + "post": { + "operationId": "get_tasks_batch_v1_tasks_batch_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TaskBatchRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + 
"description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Tasks Batch" + } + }, + "/v1/tasks/cancel": { + "post": { + "operationId": "cancel_task_v1_tasks_cancel_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TaskCancelRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Cancel Task" + } + }, + "/v1/tasks/info": { + "post": { + "operationId": "get_task_info_v1_tasks_info_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TaskInfoRequest" + } + } + }, + "required": true + }, + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Task Info" + } + }, + "/v1/tasks/{task_id}": { + "get": { + "operationId": "get_task_v1_tasks__task_id__get", + "parameters": [ + { + "in": "path", + "name": "task_id", + "required": true, + "schema": { + "title": "Task Id", + "type": "string" + } + } + ], + "responses": { + "200": { + "content": { + "application/json": { + "schema": {} + } + }, + "description": "Successful Response" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + }, + "description": "Validation Error" + } + }, + "summary": "Get Task" + } + } + } +} \ No newline at end of file diff --git a/lazyllm/tools/rag/doc_service/doc_server.py b/lazyllm/tools/rag/doc_service/doc_server.py new file mode 100644 index 000000000..1c309de42 --- /dev/null +++ b/lazyllm/tools/rag/doc_service/doc_server.py @@ -0,0 +1,813 @@ +from __future__ import annotations + +import hashlib +import json +import os +import traceback +from typing import Any, Dict, List, Optional, Set + +import requests + +from lazyllm import LOG, FastapiApp as app, ModuleBase, ServerModule, UrlModule, once_wrapper +from lazyllm.thirdparty import fastapi + +from ..utils import BaseResponse, _get_default_db_config, ensure_call_endpoint +from .base import ( + AddFileItem, + AddRequest, + AlgorithmInfoRequest, + CallbackEventType, + DeleteRequest, + DocServiceError, + DocStatus, + KbBatchQueryRequest, + KbCreateRequest, + KbDeleteBatchRequest, + KbUpdateRequest, + MetadataPatchRequest, + ReparseRequest, + SourceType, + TaskBatchRequest, + TaskCallbackPayload, + TaskCallbackRequest, + TaskCancelRequest, + TaskInfoRequest, + TransferRequest, + UploadRequest, +) +from .doc_manager import DocManager +from .utils import sha256_file + +DEFAULT_OPENAPI_OUTPUT_PATH = os.path.join(os.path.dirname(__file__), 'doc_server.openapi.json') + + +class DocServer(ModuleBase): + class _Impl: + def __init__( + self, + storage_dir: str, + db_config: Optional[Dict[str, Any]] = None, + parser_db_config: Optional[Dict[str, Any]] = None, + parser_poll_interval: float = 0.05, + parser_url: Optional[str] = None, + 
callback_url: Optional[str] = None, + ): + if not parser_url: + raise ValueError('parser_url is required; doc_service no longer starts a mock parsing server') + self._storage_dir = storage_dir + self._db_config = db_config + self._parser_db_config = parser_db_config + self._parser_poll_interval = parser_poll_interval + self._parser_url = parser_url + self._callback_url = callback_url + self._parser = None + self._manager = None + + @once_wrapper(reset_on_pickle=True) + def _lazy_init(self): + os.makedirs(self._storage_dir, exist_ok=True) + self._manager = DocManager( + db_config=self._db_config, + parser_url=self._parser_url, + callback_url=self._callback_url, + ) + + def stop(self): + return None + + def set_runtime_callback_url(self, callback_url: str): + self._lazy_init() + self._manager.set_callback_url(callback_url) + + @staticmethod + def _response(data=None, code=200, msg='success', status_code=200): + payload = BaseResponse(code=code, msg=msg, data=data).model_dump(mode='json') + return fastapi.responses.JSONResponse(status_code=status_code, content=payload) + + def _run(self, func, *args, success_msg='success', **kwargs): + try: + data = func(*args, **kwargs) + return self._response(data=data, msg=success_msg) + except DocServiceError as exc: + data = dict(exc.data or {}) + data.setdefault('biz_code', exc.biz_code) + return self._response(data=data, code=exc.http_status, msg=exc.msg, status_code=exc.http_status) + except fastapi.HTTPException as exc: + detail = exc.detail if isinstance(exc.detail, dict) else {} + data = detail.get('data') + if isinstance(data, dict) and 'biz_code' not in data and detail.get('code'): + data['biz_code'] = detail['code'] + code = exc.status_code + msg = detail.get('msg', str(exc.detail)) + return self._response(data=data, code=code, msg=msg, status_code=exc.status_code) + + @staticmethod + def _build_upload_payload(request: UploadRequest, file_identities: Optional[List[Dict[str, Any]]] = None): + source_type = request.source_type or SourceType.API + items = file_identities + if items is None: + items = [] + for idx, item in enumerate(request.items): + content_hash = None + size_bytes = None + if os.path.exists(item.file_path): + content_hash = sha256_file(item.file_path) + size_bytes = os.path.getsize(item.file_path) + items.append({ + 'filename': os.path.basename(item.file_path), + 'content_hash': content_hash, + 'size_bytes': size_bytes, + 'doc_id': item.doc_id if idx == 0 else None, + }) + return { + 'kb_id': request.kb_id, + 'algo_id': request.algo_id, + 'source_type': source_type.value, + 'idempotency_key': request.idempotency_key, + 'items': items, + } + + @staticmethod + def _build_update_kb_payload(kb_id: str, request: KbUpdateRequest): + payload = request.model_dump(mode='json', exclude_unset=True) + payload['kb_id'] = kb_id + payload['explicit_fields'] = sorted(field for field in request.model_fields_set if field != 'kb_id') + return payload + + def _gen_unique_upload_path(self, filename: str, reserved_paths: Optional[set] = None): + safe_name = os.path.basename(filename) or 'upload.bin' + file_path = os.path.join(self._storage_dir, safe_name) + reserved_paths = reserved_paths or set() + if file_path not in reserved_paths and not os.path.exists(file_path): + return file_path + + suffix = os.path.splitext(safe_name)[1] + prefix = safe_name[:-len(suffix)] if suffix else safe_name + for idx in range(1, 10000): + candidate = os.path.join(self._storage_dir, f'{prefix}-{idx}{suffix}') + if candidate not in reserved_paths and not 
os.path.exists(candidate): + return candidate + digest = hashlib.sha256(safe_name.encode()).hexdigest()[:8] + return os.path.join(self._storage_dir, f'{prefix}-{digest}{suffix}') + + @staticmethod + async def _save_upload_file(upload_file: 'fastapi.UploadFile', file_path: str): + with open(file_path, 'wb') as fh: + while True: + chunk = await upload_file.read(1024 * 1024) + if not chunk: + break + fh.write(chunk) + await upload_file.close() + + async def _persist_uploads(self, files: List['fastapi.UploadFile']): + saved_paths = [] + file_identities = [] + reserved_paths: Set[str] = set() + for upload_file in files: + filename = getattr(upload_file, 'filename', None) or 'upload.bin' + file_path = self._gen_unique_upload_path(filename, reserved_paths) + await self._save_upload_file(upload_file, file_path) + reserved_paths.add(file_path) + saved_paths.append(file_path) + file_identities.append({ + 'filename': os.path.basename(file_path), + 'content_hash': sha256_file(file_path), + 'size_bytes': os.path.getsize(file_path), + 'doc_id': None, + }) + return saved_paths, file_identities + + def _run_upload(self, request: UploadRequest, payload: Optional[Dict[str, Any]] = None): + idem_payload = payload or self._build_upload_payload(request) + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/upload', request.idempotency_key, idem_payload, + lambda: {'items': self._manager.upload(request)} + )) + + @staticmethod + def _normalize_task_callback(callback: Any) -> TaskCallbackRequest: + if isinstance(callback, TaskCallbackRequest): + return callback + if not isinstance(callback, dict): + raise DocServiceError('E_INVALID_PARAM', 'invalid callback payload') + + payload = dict(callback.get('payload') or {}) + for field in ('task_type', 'doc_id', 'kb_id', 'algo_id'): + if callback.get(field) is not None and field not in payload: + payload[field] = callback[field] + + event_type = callback.get('event_type') + status = callback.get('status') + task_status = callback.get('task_status') + + try: + if status is not None: + normalized_status = DocStatus(status) + normalized_event_type = CallbackEventType(event_type) if event_type else ( + CallbackEventType.START + if normalized_status in (DocStatus.WAITING, DocStatus.WORKING) else CallbackEventType.FINISH + ) + elif task_status is not None: + normalized_status = DocStatus(task_status) + normalized_event_type = CallbackEventType(event_type) if event_type else ( + CallbackEventType.START + if normalized_status in (DocStatus.WAITING, DocStatus.WORKING) + else CallbackEventType.FINISH + ) + else: + raise DocServiceError('E_INVALID_PARAM', 'status or task_status is required') + except ValueError as exc: + raise DocServiceError('E_INVALID_PARAM', str(exc)) from exc + + callback_data = { + 'callback_id': callback.get('callback_id'), + 'task_id': callback.get('task_id'), + 'event_type': normalized_event_type, + 'status': normalized_status, + 'error_code': callback.get('error_code'), + 'error_msg': callback.get('error_msg'), + 'payload': payload, + } + return TaskCallbackRequest.model_validate({k: v for k, v in callback_data.items() if v is not None}) + + @staticmethod + def _format_task_view(task: Optional[Dict[str, Any]]): + if not isinstance(task, dict): + return task + return dict(task) + + def _format_task_response_data(self, data: Any): + if isinstance(data, dict) and isinstance(data.get('items'), list): + payload = dict(data) + payload['items'] = [self._format_task_view(item) for item in data['items']] + return payload + return 
self._format_task_view(data) + + def upload_request(self, request: UploadRequest): + self._lazy_init() + return self._run_upload(request) + + @app.post('/v1/docs/upload') + async def upload( + self, + files: List['fastapi.UploadFile'] = fastapi.File(...), # noqa: B008 + kb_id: Optional[str] = fastapi.Form(None), # noqa: B008 + algo_id: Optional[str] = fastapi.Form(None), # noqa: B008 + source_type: Optional[SourceType] = fastapi.Form(None), # noqa: B008 + doc_id: Optional[str] = fastapi.Form(None), # noqa: B008 + idempotency_key: Optional[str] = fastapi.Form(None), # noqa: B008 + ): + self._lazy_init() + if not files: + raise fastapi.HTTPException(status_code=400, detail='files is required') + kb_id = kb_id or '__default__' + algo_id = algo_id or '__default__' + source_type = source_type or SourceType.API + saved_paths, file_identities = await self._persist_uploads(files) + upload_request = UploadRequest( + items=[ + AddFileItem(file_path=path, doc_id=(doc_id if idx == 0 else None)) + for idx, path in enumerate(saved_paths) + ], + kb_id=kb_id, + algo_id=algo_id, + source_type=source_type, + idempotency_key=idempotency_key, + ) + if file_identities: + file_identities[0]['doc_id'] = doc_id + return self._run_upload(upload_request, self._build_upload_payload(upload_request, file_identities)) + + @app.post('/v1/docs/add') + def add(self, request: AddRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/add', request.idempotency_key, payload, lambda: {'items': self._manager.add_files(request)} + )) + + @app.post('/v1/docs/reparse') + def reparse(self, request: ReparseRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/reparse', request.idempotency_key, payload, + lambda: {'task_ids': self._manager.reparse(request)} + )) + + @app.post('/v1/docs/delete') + def delete(self, request: DeleteRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/delete', request.idempotency_key, payload, lambda: {'items': self._manager.delete(request)} + )) + + @app.post('/v1/docs/transfer') + def transfer(self, request: TransferRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/transfer', request.idempotency_key, payload, + lambda: {'items': self._manager.transfer(request)} + )) + + @app.get('/v1/docs') + def list_docs( + self, + status: Optional[List[str]] = None, + kb_id: Optional[str] = None, + algo_id: Optional[str] = None, + keyword: Optional[str] = None, + include_deleted_or_canceled: bool = True, + page: int = 1, + page_size: int = 20, + ): + self._lazy_init() + data = self._manager.list_docs( + status=status, + kb_id=kb_id, + algo_id=algo_id, + keyword=keyword, + include_deleted_or_canceled=include_deleted_or_canceled, + page=page, + page_size=page_size, + ) + return BaseResponse(code=200, msg='success', data=data) + + @app.get('/v1/docs/{doc_id}') + def get_doc(self, doc_id: str): + self._lazy_init() + return self._run(lambda: self._manager.get_doc_detail(doc_id)) + + @app.post('/v1/docs/metadata/patch') + def patch_metadata(self, request: MetadataPatchRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/docs/metadata/patch', request.idempotency_key, payload, + lambda: 
self._manager.patch_metadata(request) + )) + + @app.get('/v1/tasks') + def list_tasks(self, status: Optional[List[str]] = None, page: int = 1, page_size: int = 20): + self._lazy_init() + resp = self._manager.list_tasks(status, page, page_size) + return self._response( + data=self._format_task_response_data(resp.data), + code=resp.code, + msg=resp.msg, + status_code=resp.code, + ) + + @app.get('/v1/tasks/{task_id}') + def get_task(self, task_id: str): + self._lazy_init() + resp = self._manager.get_task(task_id) + return self._response( + data=self._format_task_response_data(resp.data), + code=resp.code, + msg=resp.msg, + status_code=resp.code, + ) + + def cancel_task_by_id(self, task_id: str): + self._lazy_init() + resp = self._manager.cancel_task(task_id) + return self._response(data=resp.data, code=resp.code, msg=resp.msg, status_code=resp.code) + + @app.post('/v1/tasks/cancel') + def cancel_task(self, request: TaskCancelRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + + def _cancel(): + resp = self._manager.cancel_task(request.task_id) + if resp.code == 404: + raise DocServiceError('E_NOT_FOUND', resp.msg, resp.data) + if resp.code == 409: + raise DocServiceError('E_STATE_CONFLICT', resp.msg, resp.data) + if resp.code != 200: + raise DocServiceError('E_INVALID_PARAM', resp.msg, resp.data) + return resp.data + return self._run(lambda: self._manager.run_idempotent( + '/v1/tasks/cancel', request.idempotency_key, payload, _cancel + )) + + def task_callback(self, callback: Any): + self._lazy_init() + return self._run(lambda: self._manager.on_task_callback(self._normalize_task_callback(callback))) + + @app.post('/v1/internal/callbacks/tasks') + def task_callback_http(self, request: TaskCallbackPayload): + return self.task_callback(request.model_dump(mode='json', exclude_none=True)) + + @app.get('/v1/algo/list') + def list_algo(self): + self._lazy_init() + return self._run(lambda: self._manager.list_algorithms()) + + @app.get('/v1/algo/{algo_id}/groups') + def get_algo_groups(self, algo_id: str): + self._lazy_init() + return self._run(lambda: self._manager.get_algo_groups(algo_id)) + + @app.get('/v1/algorithms') + def list_algorithms(self): + self._lazy_init() + return self._run(lambda: self._manager.list_algorithms_compat()) + + def list_algorithms_impl(self): + self._lazy_init() + return self._run(lambda: self._manager.list_algorithms_compat()) + + @app.post('/v1/algorithms/info') + def get_algorithm_info(self, request: AlgorithmInfoRequest): + self._lazy_init() + return self._run(lambda: self._manager.get_algorithm_info(request.algo_id)) + + def get_algorithm_info_impl(self, algo_id: str): + self._lazy_init() + return self._run(lambda: self._manager.get_algorithm_info(algo_id)) + + @app.get('/v1/chunks') + def list_chunks( + self, + kb_id: str, + doc_id: str, + group: str, + algo_id: str = '__default__', + page: int = 1, + page_size: int = 20, + offset: Optional[int] = None, + ): + self._lazy_init() + return self._run(lambda: self._manager.list_chunks( + kb_id=kb_id, doc_id=doc_id, group=group, algo_id=algo_id, + page=page, page_size=page_size, offset=offset, + )) + + @app.post('/v1/tasks/batch') + def get_tasks_batch(self, request: TaskBatchRequest): + self._lazy_init() + return self._run(lambda: self._manager.get_tasks_batch(request.task_ids)) + + def get_tasks_batch_impl(self, task_ids: List[str]): + self._lazy_init() + return self._run(lambda: self._manager.get_tasks_batch(task_ids)) + + @app.post('/v1/tasks/info') + def get_task_info(self, request: 
TaskInfoRequest): + self._lazy_init() + resp = self._manager.get_task(request.task_id) + return self._response( + data=self._format_task_response_data(resp.data), + code=resp.code, + msg=resp.msg, + status_code=resp.code, + ) + + def get_task_info_impl(self, task_id: str): + self._lazy_init() + resp = self._manager.get_task(task_id) + return self._response( + data=self._format_task_response_data(resp.data), + code=resp.code, + msg=resp.msg, + status_code=resp.code, + ) + + @app.get('/v1/kbs') + def list_kbs( + self, + page: int = 1, + page_size: int = 20, + keyword: Optional[str] = None, + status: Optional[List[str]] = None, + owner_id: Optional[str] = None, + ): + self._lazy_init() + return self._run(lambda: self._manager.list_kbs( + page=page, + page_size=page_size, + keyword=keyword, + status=status, + owner_id=owner_id, + )) + + @app.get('/v1/kbs/{kb_id}') + def get_kb(self, kb_id: str): + self._lazy_init() + return self._run(lambda: self._manager.get_kb(kb_id)) + + def create_kb_by_id(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + algo_id: str = '__default__'): + self._lazy_init() + return self._run(lambda: self._manager.create_kb( + kb_id, + display_name=display_name, + description=description, + owner_id=owner_id, + meta=meta, + algo_id=algo_id, + )) + + @app.post('/v1/kbs') + def create_kb(self, request: KbCreateRequest): + self._lazy_init() + if not request.kb_id: + raise DocServiceError('E_INVALID_PARAM', 'kb_id is required') + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/kbs', request.idempotency_key, payload, + lambda: self._manager.create_kb( + request.kb_id, + display_name=request.display_name, + description=request.description, + owner_id=request.owner_id, + meta=request.meta, + algo_id=request.algo_id, + ) + )) + + def update_kb_by_id(self, kb_id: str, request: KbUpdateRequest): + self._lazy_init() + if request.kb_id and request.kb_id != kb_id: + raise DocServiceError( + 'E_INVALID_PARAM', + f'kb_id mismatch: path={kb_id}, body={request.kb_id}', + {'kb_id': kb_id, 'request_kb_id': request.kb_id}, + ) + payload = self._build_update_kb_payload(kb_id, request) + return self._run(lambda: self._manager.run_idempotent( + f'/v1/kbs/{kb_id}:patch', request.idempotency_key, payload, + lambda: self._manager.update_kb( + kb_id, + display_name=request.display_name, + description=request.description, + owner_id=request.owner_id, + meta=request.meta, + algo_id=request.algo_id, + explicit_fields=set(request.model_fields_set), + ) + )) + + @app.post('/v1/kbs/{kb_id}/update') + def update_kb(self, kb_id: str, request: KbUpdateRequest): + return self.update_kb_by_id(kb_id, request) + + @app.post('/v1/kbs/batch') + def batch_get_kbs(self, request: KbBatchQueryRequest): + self._lazy_init() + return self._run(lambda: self._manager.batch_get_kbs(request.kb_ids)) + + @app.delete('/v1/kbs/{kb_id}') + def delete_kb(self, kb_id: str, idempotency_key: Optional[str] = None): + self._lazy_init() + payload = {'kb_id': kb_id} + return self._run(lambda: self._manager.run_idempotent( + '/v1/kbs/{kb_id}:delete', idempotency_key, payload, lambda: self._manager.delete_kb(kb_id) + )) + + @app.delete('/v1/kbs') + def delete_kbs(self, request: KbDeleteBatchRequest): + self._lazy_init() + payload = request.model_dump(mode='json') + return self._run(lambda: self._manager.run_idempotent( + '/v1/kbs:delete', request.idempotency_key, payload, 
lambda: self._manager.delete_kbs(request.kb_ids) + )) + + def delete_kbs_impl(self, kb_ids: List[str], idempotency_key: Optional[str] = None): + self._lazy_init() + payload = {'kb_ids': kb_ids} + return self._run(lambda: self._manager.run_idempotent( + '/v1/kbs:delete', idempotency_key, payload, lambda: self._manager.delete_kbs(kb_ids) + )) + + @app.get('/v1/health') + def health(self): + self._lazy_init() + return BaseResponse(code=200, msg='success', data=self._manager.health()) + + @app.get('/v1/internal/parser-url') + def get_parser_url(self): + self._lazy_init() + return BaseResponse(code=200, msg='success', data={'parser_url': self._parser_url}) + + def __call__(self, func_name: str, *args, **kwargs): + return getattr(self, func_name)(*args, **kwargs) + + def __init__( + self, + port: Optional[int] = None, + url: Optional[str] = None, + parser_url: Optional[str] = None, + db_config: Optional[Dict[str, Any]] = None, + parser_db_config: Optional[Dict[str, Any]] = None, + parser_poll_interval: float = 0.05, + storage_dir: Optional[str] = None, + callback_url: Optional[str] = None, + pythonpath: Optional[str] = None, + launcher=None, + ): + super().__init__() + self._raw_impl = None + self._storage_dir = storage_dir or os.path.join(os.getcwd(), '.doc_service_uploads') + self._db_config = db_config or _get_default_db_config('doc_service') + self._parser_db_config = parser_db_config or _get_default_db_config('doc_service_parser') + if url: + self._impl = UrlModule(url=ensure_call_endpoint(url)) + else: + if not parser_url: + raise ValueError('parser_url is required; doc_service no longer embeds a mock parsing server') + self._raw_impl = DocServer._Impl( + storage_dir=self._storage_dir, + db_config=self._db_config, + parser_db_config=self._parser_db_config, + parser_poll_interval=parser_poll_interval, + parser_url=parser_url, + callback_url=callback_url, + ) + self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher, pythonpath=pythonpath) + + @staticmethod + def _register_openapi_routes(openapi_app: 'fastapi.FastAPI', impl: 'DocServer._Impl'): + def _find_services(cls): + if '__relay_services__' not in dir(cls): + return + if '__relay_services__' in cls.__dict__: + for (method, path), (name, kw) in cls.__relay_services__.items(): + if getattr(impl.__class__, name) is getattr(cls, name): + route_method = getattr(openapi_app, 'get' if method == 'list' else method) + route_method(path, **kw)(getattr(impl, name)) + for base in cls.__bases__: + _find_services(base) + + app.update() + _find_services(impl.__class__) + + @classmethod + def build_openapi_app(cls, title: str = 'LazyLLM DocService API', version: str = '1.0.0'): + openapi_app = fastapi.FastAPI( + title=title, + version=version, + description='OpenAPI schema generated from current DocServer routes.', + ) + impl = cls._Impl( + storage_dir=os.path.join(os.getcwd(), '.doc_service_openapi'), + parser_url='http://127.0.0.1:9966', + ) + cls._register_openapi_routes(openapi_app, impl) + for route in openapi_app.routes: + body_field = getattr(route, 'body_field', None) + annotation = getattr(getattr(body_field, 'field_info', None), 'annotation', None) + if hasattr(annotation, 'model_rebuild'): + annotation.model_rebuild(force=True, _types_namespace=route.endpoint.__globals__) + return openapi_app + + @classmethod + def build_openapi_schema(cls, title: str = 'LazyLLM DocService API', version: str = '1.0.0'): + return cls.build_openapi_app(title=title, version=version).openapi() + + @classmethod + def export_openapi( + cls, + 
output_path: str = DEFAULT_OPENAPI_OUTPUT_PATH, + title: str = 'LazyLLM DocService API', + version: str = '1.0.0', + ): + schema = cls.build_openapi_schema(title=title, version=version) + output_dir = os.path.dirname(output_path) + if output_dir: + os.makedirs(output_dir, exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as fh: + json.dump(schema, fh, ensure_ascii=False, indent=2, sort_keys=True) + return output_path + + def start(self): + result = super().start() + if self._raw_impl and isinstance(self._impl, ServerModule): + try: + callback_url = self._impl._url.rsplit('/', 1)[0] + '/v1/internal/callbacks/tasks' + self._dispatch('set_runtime_callback_url', callback_url) + except Exception as exc: + LOG.warning(f'[DocServer] failed to set runtime callback url: {exc}') + return result + + def stop(self): + if self._raw_impl: + try: + self._dispatch('stop') + except Exception as exc: + LOG.warning(f'[DocServer] stop impl failed: {exc}, {traceback.format_exc()}') + if isinstance(self._impl, ServerModule): + self._impl.stop() + + @property + def url(self): + return self._impl._url + + @property + def _url(self): + return self.url + + @property + def parser_url(self): + if self._raw_impl: + return self._raw_impl._parser_url + base_url = self.url.rsplit('/', 1)[0] + try: + response = requests.get(f'{base_url}/v1/internal/parser-url', timeout=5) + response.raise_for_status() + return response.json()['data']['parser_url'] + except (requests.RequestException, KeyError, TypeError, ValueError) as exc: + LOG.warning(f'[DocServer] failed to resolve remote parser_url from {base_url}: {exc}') + return None + + @staticmethod + def _normalize_dispatch_result(result): + if isinstance(result, fastapi.responses.JSONResponse): + return json.loads(result.body.decode()) + return result + + def _dispatch(self, method: str, *args, **kwargs): + if isinstance(self._impl, ServerModule): + return self._normalize_dispatch_result(self._impl._call(method, *args, **kwargs)) + return self._normalize_dispatch_result(getattr(self._impl, method)(*args, **kwargs)) + + # Method-call style wrappers + def upload(self, request: UploadRequest): + return self._dispatch('upload_request', request) + + def add(self, request: AddRequest): + return self._dispatch('add', request) + + def reparse(self, request: ReparseRequest): + return self._dispatch('reparse', request) + + def delete(self, request: DeleteRequest): + return self._dispatch('delete', request) + + def transfer(self, request: TransferRequest): + return self._dispatch('transfer', request) + + def patch_metadata(self, request: MetadataPatchRequest): + return self._dispatch('patch_metadata', request) + + def list_docs(self, **kwargs): + return self._dispatch('list_docs', **kwargs) + + def get_doc(self, doc_id: str): + return self._dispatch('get_doc', doc_id) + + def list_tasks(self, **kwargs): + return self._dispatch('list_tasks', **kwargs) + + def get_tasks_batch(self, task_ids: List[str]): + return self._dispatch('get_tasks_batch_impl', task_ids) + + def get_task_info(self, task_id: str): + return self._dispatch('get_task_info_impl', task_id) + + def get_task(self, task_id: str): + return self._dispatch('get_task', task_id) + + def set_runtime_callback_url(self, callback_url: str): + return self._dispatch('set_runtime_callback_url', callback_url) + + def cancel_task(self, task_id: str): + return self._dispatch('cancel_task_by_id', task_id) + + def list_kbs(self, **kwargs): + return self._dispatch('list_kbs', **kwargs) + + def get_kb(self, kb_id: str): + return 
self._dispatch('get_kb', kb_id) + + def list_chunks(self, **kwargs): + return self._dispatch('list_chunks', **kwargs) + + def list_algorithms(self): + return self._dispatch('list_algorithms_impl') + + def get_algorithm_info(self, algo_id: str): + return self._dispatch('get_algorithm_info_impl', algo_id) + + def create_kb(self, kb_id: str, display_name: Optional[str] = None, description: Optional[str] = None, + owner_id: Optional[str] = None, meta: Optional[Dict[str, Any]] = None, + algo_id: str = '__default__'): + return self._dispatch('create_kb_by_id', kb_id, display_name, description, owner_id, meta, algo_id) + + def update_kb(self, kb_id: str, request: KbUpdateRequest): + return self._dispatch('update_kb_by_id', kb_id, request) + + def batch_get_kbs(self, kb_ids: List[str]): + return self._dispatch('batch_get_kbs', KbBatchQueryRequest(kb_ids=kb_ids)) + + def delete_kb(self, kb_id: str): + return self._dispatch('delete_kb', kb_id) + + def delete_kbs(self, kb_ids: List[str]): + return self._dispatch('delete_kbs_impl', kb_ids) diff --git a/lazyllm/tools/rag/doc_service/parser_client.py b/lazyllm/tools/rag/doc_service/parser_client.py new file mode 100644 index 000000000..1173c208b --- /dev/null +++ b/lazyllm/tools/rag/doc_service/parser_client.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import requests + +from ..parsing_service.base import ( + AddDocRequest as ParsingAddDocRequest, + CancelTaskRequest as ParsingCancelTaskRequest, + DeleteDocRequest as ParsingDeleteDocRequest, + FileInfo as ParsingFileInfo, + TransferParams as ParsingTransferParams, + UpdateMetaRequest as ParsingUpdateMetaRequest, +) +from ..utils import BaseResponse +from .utils import normalize_api_base_url + + +class ParserClient: + def __init__(self, parser_url: str): + self._parser_url = normalize_api_base_url(parser_url) + + def _request(self, method: str, path: str, **kwargs): + response = requests.request(method, f'{self._parser_url}{path}', timeout=8, **kwargs) + if response.status_code >= 400: + raise RuntimeError(f'parser http error: {response.status_code} {response.text}') + return response.json() + + def _get_with_fallback(self, paths: List[str], params: Optional[Dict[str, Any]] = None): + last_error = None + for path in paths: + try: + return self._request('GET', path, params=params) + except RuntimeError as exc: + last_error = exc + if '404' not in str(exc): + raise + if last_error is not None: + raise last_error + raise RuntimeError('parser http error: no path provided') + + def health(self): + return BaseResponse.model_validate(self._request('GET', '/health')) + + def add_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, file_path: str, + metadata: Optional[Dict[str, Any]] = None, reparse_group: Optional[str] = None, + callback_url: Optional[str] = None, transfer_params: Optional[Dict[str, Any]] = None): + req = ParsingAddDocRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + callback_url=callback_url, + feedback_url=callback_url, + file_infos=[ParsingFileInfo( + file_path=file_path, + doc_id=doc_id, + metadata=metadata or {}, + reparse_group=reparse_group, + transfer_params=( + ParsingTransferParams.model_validate(transfer_params) + if transfer_params is not None else None + ), + )], + ) + return BaseResponse.model_validate(self._request('POST', '/doc/add', json=req.model_dump(mode='json'))) + + def update_meta(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, + metadata: Optional[Dict[str, Any]] = None, 
file_path: Optional[str] = None, + callback_url: Optional[str] = None): + req = ParsingUpdateMetaRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + callback_url=callback_url, + feedback_url=callback_url, + file_infos=[ParsingFileInfo(file_path=file_path, doc_id=doc_id, metadata=metadata or {})], + ) + return BaseResponse.model_validate( + self._request('POST', '/doc/meta/update', json=req.model_dump(mode='json')) + ) + + def delete_doc(self, task_id: str, algo_id: str, kb_id: str, doc_id: str, + callback_url: Optional[str] = None): + req = ParsingDeleteDocRequest( + task_id=task_id, + algo_id=algo_id, + kb_id=kb_id, + doc_ids=[doc_id], + callback_url=callback_url, + feedback_url=callback_url, + ) + return BaseResponse.model_validate(self._request('DELETE', '/doc/delete', json=req.model_dump(mode='json'))) + + def cancel_task(self, task_id: str): + req = ParsingCancelTaskRequest(task_id=task_id) + return BaseResponse.model_validate(self._request('POST', '/doc/cancel', json=req.model_dump(mode='json'))) + + def list_algorithms(self): + return BaseResponse.model_validate(self._get_with_fallback(['/v1/algo/list', '/algo/list'])) + + def get_algorithm_groups(self, algo_id: str): + try: + data = self._get_with_fallback([f'/v1/algo/{algo_id}/groups', f'/algo/{algo_id}/group/info']) + return BaseResponse.model_validate(data) + except RuntimeError as exc: + if '404' in str(exc): + return BaseResponse(code=404, msg='algo not found', data=None) + raise + + def list_doc_chunks(self, algo_id: str, kb_id: str, doc_id: str, group: str, offset: int, page_size: int): + data = self._request('GET', '/doc/chunks', params={ + 'algo_id': algo_id, + 'kb_id': kb_id, + 'doc_id': doc_id, + 'group': group, + 'offset': offset, + 'page_size': page_size, + }) + return BaseResponse.model_validate(data) diff --git a/lazyllm/tools/rag/doc_service/utils.py b/lazyllm/tools/rag/doc_service/utils.py new file mode 100644 index 000000000..9a15d35cb --- /dev/null +++ b/lazyllm/tools/rag/doc_service/utils.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import hashlib +import json +import os +from typing import Any, Dict, Optional + +from lazyllm import LOG +from ..utils import gen_docid + + +def to_json(data: Optional[Dict[str, Any]]) -> str: + return json.dumps(data or {}, ensure_ascii=False) + + +def from_json(raw: Optional[str]) -> Dict[str, Any]: + if not raw: + return {} + try: + return json.loads(raw) + except Exception: + LOG.warning('[DocService] Failed to decode json payload') + return {} + + +def gen_doc_id(file_path: str, doc_id: Optional[str] = None) -> str: + return doc_id or gen_docid(file_path) + + +def stable_json(data: Any) -> str: + return json.dumps(data, ensure_ascii=False, sort_keys=True, default=str) + + +def hash_payload(data: Any) -> str: + return hashlib.sha256(stable_json(data).encode()).hexdigest() + + +def sha256_file(file_path: str) -> str: + digest = hashlib.sha256() + with open(file_path, 'rb') as fh: + for chunk in iter(lambda: fh.read(1024 * 1024), b''): + digest.update(chunk) + return digest.hexdigest() + + +def merge_transfer_metadata( + source_metadata: Dict[str, Any], target_metadata: Optional[Dict[str, Any]] +) -> Dict[str, Any]: + metadata = dict(source_metadata or {}) + if target_metadata: + metadata.update(target_metadata) + return metadata + + +def resolve_transfer_target_path( + source_path: str, target_filename: Optional[str], target_file_path: Optional[str] +) -> str: + if target_file_path: + return target_file_path + if target_filename: + base_dir = 
os.path.dirname(source_path) if source_path else '' + return os.path.join(base_dir, target_filename) if base_dir else target_filename + return source_path + + +def normalize_api_base_url(url: str) -> str: + url = url.rstrip('/') + if url.endswith('/_call') or url.endswith('/generate'): + return url.rsplit('/', 1)[0] + return url diff --git a/lazyllm/tools/rag/doc_to_db/extractor.py b/lazyllm/tools/rag/doc_to_db/extractor.py index 089692250..c39d3ffab 100644 --- a/lazyllm/tools/rag/doc_to_db/extractor.py +++ b/lazyllm/tools/rag/doc_to_db/extractor.py @@ -13,12 +13,12 @@ from lazyllm import LOG, ThreadPoolExecutor, once_wrapper from lazyllm.components import JsonFormatter -from lazyllm.module import LLMBase +from lazyllm.module import LLMBase, ModuleBase from ...sql.sql_manager import DBStatus, SqlManager from ..doc_node import DocNode from ..global_metadata import RAG_DOC_ID, RAG_KB_ID -from ..utils import DocListManager, _orm_to_dict +from ..utils import _orm_to_dict from ..store.store_base import DEFAULT_KB_ID from .model import ( TABLE_SCHEMA_SET_INFO, Table_ALGO_KB_SCHEMA, ExtractionMode, @@ -33,7 +33,7 @@ ONE_DOC_LENGTH_LIMIT = 102400 -class SchemaExtractor: +class SchemaExtractor(ModuleBase): '''Schema aware extractor that materializes BaseModel schemas into database tables.''' TABLE_PREFIX = 'lazyllm_schema' @@ -293,7 +293,7 @@ def has_schema_set(self, schema_set_id: str) -> bool: return True return schema_set_id in self._schema_registry - def register_schema_set_to_kb(self, algo_id: Optional[str] = DocListManager.DEFAULT_GROUP_NAME, + def register_schema_set_to_kb(self, algo_id: Optional[str] = DEFAULT_KB_ID, kb_id: Optional[str] = DEFAULT_KB_ID, schema_set_id: Optional[str] = None, schema_set: Type[BaseModel] = None, force_refresh: bool = False) -> str: ''' @@ -555,7 +555,7 @@ def _validate_extract_params(self, data: Union[str, List[DocNode]], algo_id: str return kb_id, doc_id, _orm_to_dict(bound) def extract_and_store(self, data: Union[str, List[DocNode]], # noqa: C901 - algo_id: str = DocListManager.DEFAULT_GROUP_NAME, + algo_id: str = DEFAULT_KB_ID, schema_set_id: str = None, schema_set: Type[BaseModel] = None) -> ExtractResult: '''Persist extracted fields for a document''' self._lazy_init() @@ -724,8 +724,8 @@ def _get_extract_data(self, algo_id: str, doc_ids: List[str], # noqa: C901 results.append(ExtractResult(data=row_data, metadata=meta)) return results - def __call__(self, data: Union[str, List[DocNode]], - algo_id: str = DocListManager.DEFAULT_GROUP_NAME) -> ExtractResult: + def forward(self, data: Union[str, List[DocNode]], + algo_id: str = DEFAULT_KB_ID) -> ExtractResult: # NOTE: data should be from single file source (kb_id, doc_id should be the same) self._lazy_init() res = self.extract_and_store(data=data, algo_id=algo_id) diff --git a/lazyllm/tools/rag/document.py b/lazyllm/tools/rag/document.py index 0b395e050..b8edd172d 100644 --- a/lazyllm/tools/rag/document.py +++ b/lazyllm/tools/rag/document.py @@ -1,5 +1,6 @@ import os -from typing import Callable, Optional, Dict, Union, List, Type, Set +import warnings +from typing import Callable, Optional, Dict, Union, List, Type, Set, Tuple from functools import cached_property from pydantic import BaseModel import lazyllm @@ -9,20 +10,26 @@ from lazyllm.tools.sql.sql_manager import SqlManager, DBStatus from lazyllm.common.bind import _MetaBind -from .doc_manager import DocManager +from .doc_service import DocServer from .doc_impl import DocImpl, StorePlaceholder, EmbedPlaceholder, BuiltinGroups, DocumentProcessor, 
NodeGroupType from .doc_node import DocNode from .doc_to_db import DocInfoSchema, DocToDbProcessor, extract_db_schema_from_files, SchemaExtractor from .store import LAZY_ROOT_NAME, EMBED_DEFAULT_KEY from .store.store_base import DEFAULT_KB_ID from .index_base import IndexBase -from .utils import DocListManager, ensure_call_endpoint +from .utils import RAG_DEFAULT_GROUP_NAME, ensure_call_endpoint from .global_metadata import GlobalMetadataDesc as DocField from .web import DocWebModule import copy import functools import weakref +_LOCAL_PYTHONPATH = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) + + +def _is_local_map_store(store_conf: Optional[Dict]) -> bool: + return isinstance(store_conf, dict) and store_conf.get('type') == 'map' + class CallableDict(dict): def __call__(self, cls, *args, **kw): @@ -37,13 +44,25 @@ def __instancecheck__(self, __instance): class Document(ModuleBase, BuiltinGroups, metaclass=_MetaDocument): class _Manager(ModuleBase): - def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, Dict[str, Callable]]] = None, - manager: Union[bool, str] = False, server: Union[bool, int] = False, name: Optional[str] = None, - launcher: Optional[Launcher] = None, store_conf: Optional[Dict] = None, - doc_fields: Optional[Dict[str, DocField]] = None, cloud: bool = False, - doc_files: Optional[List[str]] = None, processor: Optional[DocumentProcessor] = None, - display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', - schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): + def __init__( # noqa: C901 + self, + dataset_path: Optional[str], + embed: Optional[Union[Callable, Dict[str, Callable]]] = None, + manager: Union[bool, str, DocServer] = False, + server: Union[bool, int] = False, + name: Optional[str] = None, + launcher: Optional[Launcher] = None, + store_conf: Optional[Dict] = None, + doc_fields: Optional[Dict[str, DocField]] = None, + cloud: bool = False, + doc_files: Optional[List[str]] = None, + processor: Optional[DocumentProcessor] = None, + display_name: Optional[str] = '', + description: Optional[str] = 'algorithm description', + schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, + create_ui: bool = False, + enable_path_monitoring: Optional[bool] = None, + ): super().__init__() self._origin_path, self._doc_files, self._cloud = dataset_path, doc_files, cloud @@ -58,18 +77,51 @@ def __init__(self, dataset_path: Optional[str], embed: Optional[Union[Callable, self._dataset_path = dataset_path self._embed = self._get_embeds(embed) self._processor = processor - name = name or DocListManager.DEFAULT_GROUP_NAME + self._create_ui = create_ui + self._spawn_doc_server = False + self._doc_processor_started = False + compat_ui_manager = manager == 'ui' + if compat_ui_manager: + lazyllm.LOG.warning('`manager=\'ui\'` is deprecated, use `manager=True, create_ui=True` instead') + elif isinstance(manager, str): + raise ValueError(f'Unsupported manager value: {manager}') + spawn_doc_server = bool(manager) and not isinstance(manager, DocServer) + connect_doc_server = isinstance(manager, DocServer) + doc_impl_dataset_path = dataset_path if not (spawn_doc_server or connect_doc_server) else None + self._doc_impl_dataset_path = doc_impl_dataset_path + self._doc_processor = None + if spawn_doc_server: + self._spawn_doc_server = True + self._doc_processor = DocumentProcessor(launcher=self._launcher, pythonpath=_LOCAL_PYTHONPATH) + self._submodules.remove(self._doc_processor) + elif 
connect_doc_server: + self._manager = manager + parser_url = getattr(getattr(manager, '_raw_impl', None), '_parser_url', None) + if parser_url is None: + parser_url = manager.parser_url + if parser_url: + self._doc_processor = DocumentProcessor(url=parser_url) + self._schema_extractor = schema_extractor + self._store_conf = store_conf + self._display_name = display_name + self._description = description + name = name or RAG_DEFAULT_GROUP_NAME if not display_name: display_name = name - - self._dlm = None if (self._cloud or self._doc_files is not None) else DocListManager( - dataset_path, name, enable_path_monitoring=False if manager else True) + if enable_path_monitoring is None: + enable_path_monitoring = False if (spawn_doc_server or connect_doc_server or processor) else True + self._enable_path_monitoring = enable_path_monitoring + doc_processor = self._doc_processor or processor self._kbs = CallableDict({name: DocImpl( - embed=self._embed, dlm=self._dlm, doc_files=doc_files, global_metadata_desc=doc_fields, - store=store_conf, processor=processor, algo_name=name, display_name=display_name, + embed=self._embed, + dataset_path=self._doc_impl_dataset_path, + enable_path_monitoring=enable_path_monitoring, + doc_files=doc_files, + global_metadata_desc=doc_fields, + store=store_conf, processor=doc_processor, algo_name=name, display_name=display_name, description=description, schema_extractor=schema_extractor)}) - if manager: self._manager = ServerModule(DocManager(self._dlm), launcher=self._launcher) - if manager == 'ui': self._docweb = DocWebModule(doc_server=self._manager) + if create_ui and not self._spawn_doc_server: + self.ensure_doc_web() if server: self._kbs = ServerModule(self._kbs, port=(None if isinstance(server, bool) else int(server))) self._global_metadata_desc = doc_fields @@ -88,24 +140,66 @@ def web_url(self): if hasattr(self, '_docweb'): return self._docweb.url return None + def ensure_doc_web(self): + if hasattr(self, '_docweb'): + return self._docweb + if self._spawn_doc_server and not hasattr(self, '_manager'): + raise ValueError('`create_ui=True` with `manager=True` requires `Document.start()` before using the UI') + if not hasattr(self, '_manager') or not isinstance(self._manager, DocServer): + raise ValueError( + '`create_ui=True` requires an available DocServer. ' + 'Set `manager=True` or pass `manager=DocServer(...)`.' 
+ ) + self._docweb = DocWebModule(doc_server=self._manager) + return self._docweb + + def _ensure_doc_processor_started(self): + if self._doc_processor and not self._doc_processor_started: + self._doc_processor.start() + self._doc_processor_started = True + + def _ensure_managed_services_started(self): + if self._spawn_doc_server: + self._ensure_doc_processor_started() + if not hasattr(self, '_manager'): + self._manager = DocServer( + launcher=self._launcher, + storage_dir=self._dataset_path, + parser_url=self._doc_processor.url, + pythonpath=_LOCAL_PYTHONPATH, + ) + self._manager.start() + if self._create_ui and not hasattr(self, '_docweb'): + self.ensure_doc_web() + self._docweb.start() + + def _get_deploy_tasks(self): + if self._spawn_doc_server and not hasattr(self, '_manager'): + return lazyllm.pipeline(self._ensure_managed_services_started) + return None + def _get_embeds(self, embed): embeds = embed if isinstance(embed, dict) else {EMBED_DEFAULT_KEY: embed} if embed else {} - for embed in embeds.values(): - if isinstance(embed, ModuleBase): - self._submodules.append(embed) + for index, module in enumerate(embeds.values()): + if isinstance(module, ModuleBase): + setattr(self, f'_embed_module_{index}', module) return embeds def add_kb_group(self, name, doc_fields: Optional[Dict[str, DocField]] = None, - store_conf: Optional[Dict] = None, - embed: Optional[Union[Callable, Dict[str, Callable]]] = None): + store_conf: Optional[Dict] = None, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, + schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): embed = self._get_embeds(embed) if embed else self._embed - if isinstance(self._kbs, ServerModule): - self._kbs._impl._m[name] = DocImpl(dlm=self._dlm, embed=embed, kb_group_name=name, - global_metadata_desc=doc_fields, store=store_conf) - else: - self._kbs[name] = DocImpl(dlm=self._dlm, embed=self._embed, kb_group_name=name, - global_metadata_desc=doc_fields, store=store_conf) - self._dlm.add_kb_group(name=name) + schema_extractor = schema_extractor or self._schema_extractor + if isinstance(schema_extractor, ModuleBase): + setattr(self, f'_schema_extractor_{name}', schema_extractor) + impl = DocImpl( + dataset_path=self._doc_impl_dataset_path, embed=embed, kb_group_name=name, + enable_path_monitoring=self._enable_path_monitoring, global_metadata_desc=doc_fields, + store=store_conf or self._store_conf, processor=self._doc_processor or self._processor, + algo_name=name, display_name=name, description=self._description, + schema_extractor=schema_extractor, + ) + (self._kbs._impl._m if isinstance(self._kbs, ServerModule) else self._kbs)[name] = impl def get_doc_by_kb_group(self, name): return self._kbs._impl._m[name] if isinstance(self._kbs, ServerModule) else self._kbs[name] @@ -131,16 +225,27 @@ def __new__(cls, *args, **kw): return super().__new__(cls) def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Callable, Dict[str, Callable]]] = None, - create_ui: bool = False, manager: Union[bool, str, 'Document._Manager', DocumentProcessor] = False, - server: Union[bool, int] = False, name: Optional[str] = None, launcher: Optional[Launcher] = None, - doc_files: Optional[List[str]] = None, doc_fields: Dict[str, DocField] = None, + create_ui: bool = False, + manager: Union[bool, str, DocServer, 'Document._Manager', DocumentProcessor] = False, + server: Union[bool, int] = False, name: Optional[str] = None, + launcher: Optional[Launcher] = None, doc_files: Optional[List[str]] = None, + doc_fields: Dict[str, 
DocField] = None, store_conf: Optional[Dict] = None, display_name: Optional[str] = '', description: Optional[str] = 'algorithm description', - schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None): + schema_extractor: Optional[Union[LLMBase, SchemaExtractor]] = None, + enable_path_monitoring: Optional[bool] = None): super().__init__() - if create_ui: - lazyllm.LOG.warning('`create_ui` for Document is deprecated, use `manager` instead') - manager = create_ui + if isinstance(manager, str): + if manager != 'ui': + raise ValueError(f'Unsupported manager value: {manager}') + warnings.warn( + '`manager="ui"` is deprecated, use `manager=True, create_ui=True` instead.', + DeprecationWarning, + stacklevel=2, + ) + lazyllm.LOG.warning('`manager="ui"` is deprecated, use `manager=True, create_ui=True` instead') + create_ui = True + manager = True if isinstance(dataset_path, (tuple, list)): doc_fields = dataset_path dataset_path = None @@ -150,8 +255,7 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal assert store_conf is None or store_conf['type'] == 'map', ( 'Only map store is supported for Document with temp-files') - name = name or DocListManager.DEFAULT_GROUP_NAME - self._schema_extractor: SchemaExtractor = schema_extractor + name = name or RAG_DEFAULT_GROUP_NAME if isinstance(manager, Document._Manager): assert not server, 'Server infomation is already set to by manager' @@ -161,23 +265,28 @@ def __init__(self, dataset_path: Optional[str] = None, embed: Optional[Union[Cal if dataset_path != manager._dataset_path and dataset_path != manager._origin_path: raise RuntimeError(f'Document path mismatch, expected `{manager._dataset_path}`' f'while received `{dataset_path}`') - manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed) + manager.add_kb_group(name=name, doc_fields=doc_fields, store_conf=store_conf, embed=embed, + schema_extractor=schema_extractor) + if create_ui: + manager.ensure_doc_web() self._manager = manager self._curr_group = name else: if isinstance(manager, DocumentProcessor): + if store_conf is None: + raise ValueError('`store_conf` is required when `manager` is a DocumentProcessor') + if _is_local_map_store(store_conf): + raise ValueError('`manager=DocumentProcessor(...)` does not support pure local map store') processor, cloud = manager, True - processor.start() manager = False - assert name, '`Name` of Document is necessary when using cloud service' - assert store_conf.get('type') != 'map', 'Cloud manager is not supported when using map store' - assert not dataset_path, 'Cloud manager is not supported with local dataset path' else: cloud, processor = False, None self._manager = Document._Manager(dataset_path, embed, manager, server, name, launcher, store_conf, doc_fields, cloud=cloud, doc_files=doc_files, processor=processor, display_name=display_name, description=description, - schema_extractor=self._schema_extractor) + schema_extractor=schema_extractor, + create_ui=create_ui, + enable_path_monitoring=enable_path_monitoring) self._curr_group = name self._doc_to_db_processor: DocToDbProcessor = None self._graph_document: weakref.ref = None @@ -368,6 +477,9 @@ def register_index(self, index_type: str, index_cls: IndexBase, *args, **kwargs) def _forward(self, func_name: str, *args, **kw): return self._manager(self._curr_group, func_name, *args, **kw) + def start(self): + return super().start() + def find_parent(self, target) -> Callable: return functools.partial(self._forward, 'find_parent', 
group=target) @@ -394,9 +506,12 @@ def register_schema_set(self, schema_set: Type[BaseModel], kb_id: Optional[str] return self._forward('_register_schema_set', schema_set, kb_id, force_refresh) def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, - group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None - ) -> List[DocNode]: - return self._forward('_get_nodes', uids, doc_ids, group, kb_id, numbers) + group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None, + limit: Optional[int] = None, offset: int = 0, return_total: bool = False, + sort_by_number: bool = False) -> Union[List[DocNode], Tuple[List[DocNode], int]]: + return self._forward( + '_get_nodes', uids, doc_ids, group, kb_id, numbers, limit, offset, return_total, sort_by_number, + ) def get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5), merge: bool = False) -> Union[List[DocNode], DocNode]: @@ -414,7 +529,7 @@ def __init__(self, url: str, name: str = None): super().__init__() self._missing_keys = set(dir(Document)) - set(dir(UrlDocument)) self._manager = lazyllm.UrlModule(url=ensure_call_endpoint(url)) - self._curr_group = name or DocListManager.DEFAULT_GROUP_NAME + self._curr_group = name or RAG_DEFAULT_GROUP_NAME def _forward(self, func_name: str, *args, **kwargs): args = (self._curr_group, func_name, *args) @@ -427,9 +542,12 @@ def forward(self, *args, **kw): return self._forward('retrieve', *args, **kw) def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, - group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None - ) -> List[DocNode]: - return self._forward('_get_nodes', uids, doc_ids, group, kb_id, numbers) + group: Optional[str] = None, kb_id: Optional[str] = None, numbers: Optional[Set] = None, + limit: Optional[int] = None, offset: int = 0, return_total: bool = False, + sort_by_number: bool = False) -> Union[List[DocNode], Tuple[List[DocNode], int]]: + return self._forward( + '_get_nodes', uids, doc_ids, group, kb_id, numbers, limit, offset, return_total, sort_by_number, + ) def get_window_nodes(self, node: DocNode, span: tuple[int, int] = (-5, 5), merge: bool = False) -> Union[List[DocNode], DocNode]: diff --git a/lazyllm/tools/rag/parsing_service/base.py b/lazyllm/tools/rag/parsing_service/base.py index b197da80c..19b4a65d3 100644 --- a/lazyllm/tools/rag/parsing_service/base.py +++ b/lazyllm/tools/rag/parsing_service/base.py @@ -42,10 +42,19 @@ class AddDocRequest(BaseModel): kb_id: Optional[str] = None file_infos: List[FileInfo] priority: Optional[int] = 0 + callback_url: Optional[str] = None # NOTE: (db_info, feedback_url) is deprecated, will be removed in the future db_info: EmptyDBInfo = None feedback_url: Optional[str] = None + @model_validator(mode='before') + @classmethod + def normalize_deprecated_fields(cls, data): + if isinstance(data, dict) and not data.get('db_info'): + data = dict(data) + data['db_info'] = None + return data + class UpdateMetaRequest(BaseModel): task_id: str = Field(default_factory=lambda: str(uuid4())) @@ -53,8 +62,18 @@ class UpdateMetaRequest(BaseModel): kb_id: Optional[str] = None file_infos: List[FileInfo] priority: Optional[int] = 0 + callback_url: Optional[str] = None # NOTE: (db_info) is deprecated, will be removed in the future db_info: EmptyDBInfo = None + feedback_url: Optional[str] = None + + @model_validator(mode='before') + @classmethod + def normalize_deprecated_fields(cls, data): + if isinstance(data, 
dict) and not data.get('db_info'): + data = dict(data) + data['db_info'] = None + return data class DeleteDocRequest(BaseModel): @@ -63,15 +82,21 @@ class DeleteDocRequest(BaseModel): kb_id: Optional[str] = None doc_ids: List[str] priority: Optional[int] = 0 + callback_url: Optional[str] = None # NOTE: (db_info) is deprecated, will be removed in the future db_info: EmptyDBInfo = None + feedback_url: Optional[str] = None @model_validator(mode='before') @classmethod - def _compat_dataset_id(cls, data): - if isinstance(data, dict) and not data.get('kb_id') and data.get('dataset_id'): - data = dict(data) + def normalize_legacy_fields(cls, data): + if not isinstance(data, dict): + return data + data = dict(data) + if not data.get('kb_id') and data.get('dataset_id'): data['kb_id'] = data['dataset_id'] + if not data.get('db_info'): + data['db_info'] = None return data @@ -82,9 +107,8 @@ class CancelTaskRequest(BaseModel): class TaskStatus(str, Enum): WAITING = 'WAITING' WORKING = 'WORKING' - CANCEL_REQUESTED = 'CANCEL_REQUESTED' CANCELED = 'CANCELED' - FINISHED = 'FINISHED' + SUCCESS = 'SUCCESS' FAILED = 'FAILED' @@ -183,9 +207,12 @@ def _resolve_add_doc_task_type(request: AddDocRequest) -> str: # noqa: C901 } # Finished task queue table +# NOTE: callback-related columns were appended after the initial queue schema. Existing deployments may still +# use an older table layout, but queue initialization already auto-adds any missing nullable columns in place +# via ``_SQLBasedQueue._ensure_columns_exist()``, so startup remains backward compatible without extra migration code. FINISHED_TASK_QUEUE_TABLE_INFO = { 'name': 'lazyllm_finished_task_queue', - 'comment': 'Finished task queue table', + 'comment': 'Finished task queue table; legacy tables are extended in place with new columns at startup', 'columns': [ {'name': 'id', 'data_type': 'integer', 'nullable': False, 'is_primary_key': True, 'comment': 'Auto increment ID'}, @@ -194,9 +221,13 @@ def _resolve_add_doc_task_type(request: AddDocRequest) -> str: # noqa: C901 {'name': 'task_type', 'data_type': 'string', 'nullable': False, 'comment': 'Task type: DOC_ADD, DOC_DELETE, DOC_UPDATE_META, DOC_REPARSE'}, {'name': 'task_status', 'data_type': 'string', 'nullable': False, - 'comment': 'Task status: WAITING, WORKING, CANCEL_REQUESTED, CANCELED, FINISHED, FAILED'}, + 'comment': 'Task status: WAITING, WORKING, CANCELED, SUCCESS, FAILED'}, {'name': 'finished_at', 'data_type': 'datetime', 'nullable': False, 'comment': 'Finish time (set when processing completes)', 'default': datetime.now}, + {'name': 'callback_url', 'data_type': 'string', 'nullable': True, + 'comment': 'Callback target url for built-in HTTP callback'}, + {'name': 'task_context_json', 'data_type': 'string', 'nullable': True, + 'comment': 'Serialized callback context used to build callback payload'}, {'name': 'error_code', 'data_type': 'string', 'nullable': True, 'default': '200', 'comment': 'Error code (varchar64)'}, {'name': 'error_msg', 'data_type': 'string', 'nullable': True, 'default': 'success', diff --git a/lazyllm/tools/rag/parsing_service/impl.py b/lazyllm/tools/rag/parsing_service/impl.py index cd520663d..d80a6a73f 100644 --- a/lazyllm/tools/rag/parsing_service/impl.py +++ b/lazyllm/tools/rag/parsing_service/impl.py @@ -5,7 +5,6 @@ from collections import defaultdict, deque from concurrent.futures import ThreadPoolExecutor from functools import cached_property - from lazyllm import LOG from ..data_loaders import DirectoryReader @@ -90,22 +89,31 @@ def store(self) -> _DocumentStore: 
def reader(self) -> DirectoryReader: return self._reader + @staticmethod + def _prepare_doc_inputs(input_files: List[str], ids: Optional[List[str]] = None, + metadatas: Optional[List[Dict[str, Any]]] = None, + kb_id: Optional[str] = None) -> tuple[List[str], List[Dict[str, Any]], str]: + ids = ids or [gen_docid(path) for path in input_files] + normalized_metadatas = metadatas or [{} for _ in input_files] + for i, (doc_id, path) in enumerate(zip(ids, input_files)): + metadata = normalized_metadatas[i] or {} + metadata.setdefault(RAG_DOC_ID, doc_id) + metadata.setdefault(RAG_DOC_PATH, path) + metadata.setdefault(RAG_KB_ID, kb_id or DEFAULT_KB_ID) + normalized_metadatas[i] = metadata + resolved_kb_id = normalized_metadatas[0].get(RAG_KB_ID, DEFAULT_KB_ID) if kb_id is None else kb_id + return ids, normalized_metadatas, resolved_kb_id + def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # noqa: C901 metadatas: Optional[List[Dict[str, Any]]] = None, kb_id: Optional[str] = None, transfer_mode: Optional[str] = None, target_kb_id: Optional[str] = None, target_doc_ids: Optional[List[str]] = None, preloaded_root_nodes: Optional[Dict[str, List[DocNode]]] = None): + ids = ids or [] try: if not input_files: return add_start = time.time() - if not ids: ids = [gen_docid(path) for path in input_files] - if metadatas is None: - metadatas = [{} for _ in input_files] - for metadata, doc_id, path in zip(metadatas, ids, input_files): - metadata.setdefault(RAG_DOC_ID, doc_id) - metadata.setdefault(RAG_DOC_PATH, path) - metadata.setdefault(RAG_KB_ID, kb_id or DEFAULT_KB_ID) - kb_id = metadatas[0].get(RAG_KB_ID, DEFAULT_KB_ID) if kb_id is None else kb_id + ids, metadatas, kb_id = self._prepare_doc_inputs(input_files, ids, metadatas, kb_id) load_start = time.time() if transfer_mode is None: @@ -173,9 +181,27 @@ def add_doc(self, input_files: List[str], ids: Optional[List[str]] = None, # no LOG.info(f'[_Processor - add_doc] Add documents done! 
files:{input_files}, ' f'Total Time: {add_time}s, Data Loading Time: {load_time}s') except Exception as e: + cleanup_doc_ids = ids if transfer_mode is None else (target_doc_ids or []) + cleanup_kb_id = kb_id if transfer_mode is None else target_kb_id + self._cleanup_failed_add(cleanup_doc_ids, cleanup_kb_id, clear_schema=transfer_mode is None) LOG.error(f'Add documents failed: {e}, {traceback.format_exc()}') raise e + def _cleanup_failed_add(self, doc_ids: List[str], kb_id: Optional[str], clear_schema: bool) -> None: + if not doc_ids: + return + try: + self._store.remove_nodes(doc_ids=doc_ids, kb_id=kb_id) + except Exception as cleanup_exc: + LOG.error(f'Failed to cleanup nodes for docs {doc_ids} in kb {kb_id}: {cleanup_exc}, ' + f'{traceback.format_exc()}') + if clear_schema and self._schema_extractor: + try: + self._schema_extractor._delete_extract_data(algo_id=self._algo_id, kb_id=kb_id, doc_ids=doc_ids) + except Exception as cleanup_exc: + LOG.error(f'Failed to cleanup schema data for docs {doc_ids} in kb {kb_id}: {cleanup_exc}, ' + f'{traceback.format_exc()}') + def close(self): self._thread_pool.shutdown(wait=True) self._thread_pool = None @@ -277,7 +303,7 @@ def _reparse_docs(self, group_name: str, doc_ids: List[str], doc_paths: List[str kb_id: str = None, **kwargs): if not metadatas: raise ValueError('metadatas is required for reparse') - kb_id = metadatas[0].get(RAG_KB_ID, None) if kb_id is None else kb_id + doc_ids, metadatas, kb_id = self._prepare_doc_inputs(doc_paths, doc_ids, metadatas, kb_id) if group_name == 'all': preloaded_root_nodes = self._reader.load_data(doc_paths, metadatas, split_nodes_by_type=True) self._store.remove_nodes(doc_ids=doc_ids, kb_id=kb_id) diff --git a/lazyllm/tools/rag/parsing_service/queue.py b/lazyllm/tools/rag/parsing_service/queue.py index 785cfeb45..5293c6323 100644 --- a/lazyllm/tools/rag/parsing_service/queue.py +++ b/lazyllm/tools/rag/parsing_service/queue.py @@ -2,6 +2,7 @@ from datetime import datetime, timedelta from lazyllm import LOG +import sqlalchemy from ...sql import SqlManager from ..utils import _orm_to_dict @@ -30,11 +31,44 @@ def __init__(self, table_name: str, columns: List[Dict[str, Any]], db_config: Di ] } ) + self._ensure_columns_exist() LOG.info(f'[SQLBasedQueue] Queue {self._table_name} initialized successfully') except Exception as e: LOG.error(f'[SQLBasedQueue] Failed to initialize queue {self._table_name}: {e}') raise + def _ensure_columns_exist(self): + # Keep queue tables backward compatible by adding newly introduced columns + # to existing tables instead of requiring a manual migration step. 
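+        # Illustrative only: for a legacy table missing the new nullable ``callback_url`` column,
+        # the loop below would emit a statement roughly like
+        #   ALTER TABLE "lazyllm_finished_task_queue" ADD COLUMN "callback_url" VARCHAR
+        # (exact type/default rendering depends on the SQLAlchemy dialect in use).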
+ inspector = sqlalchemy.inspect(self._sql_manager.engine) + existing_columns = {column['name'] for column in inspector.get_columns(self._table_name)} + missing_columns = [column for column in self._columns if column['name'] not in existing_columns] + if not missing_columns: + return + + for column in missing_columns: + sql_type = self._sql_manager._sql_type_for(column['data_type']) + if isinstance(sql_type, type): + sql_type = sql_type() + type_sql = sql_type.compile(dialect=self._sql_manager.engine.dialect) + nullable_sql = '' if column.get('nullable', True) else ' NOT NULL' + default_sql = '' + default_value = column.get('default') + if default_value is not None and not callable(default_value): + if isinstance(default_value, str): + escaped_default = default_value.replace("'", "''") + default_sql = f" DEFAULT '{escaped_default}'" + elif isinstance(default_value, bool): + default_sql = f' DEFAULT {int(default_value)}' + else: + default_sql = f' DEFAULT {default_value}' + statement = sqlalchemy.text( + f'ALTER TABLE "{self._table_name}" ADD COLUMN "{column["name"]}" {type_sql}{nullable_sql}{default_sql}' + ) + with self._sql_manager.engine.begin() as connection: + connection.execute(statement) + LOG.info(f'[SQLBasedQueue] Added missing column {column["name"]} to {self._table_name}') + def _build_query(self, session, filter_by: Dict[str, Any] = None): TableCls = self._sql_manager.get_table_orm_class(self._table_name) query = session.query(TableCls) @@ -214,3 +248,21 @@ def clear(self, filter_by: Dict[str, Any] = None) -> int: except Exception as e: LOG.error(f'[SQLBasedQueue] Failed to clear {self._table_name}: {e}') raise + + def update(self, filter_by: Dict[str, Any], **kwargs) -> Optional[Dict[str, Any]]: + try: + with self._sql_manager.get_session() as session: + query = self._build_query(session, filter_by) + record = query.with_for_update().first() + if not record: + return None + + for key, value in kwargs.items(): + setattr(record, key, value) + session.flush() + result = _orm_to_dict(record) + LOG.debug(f'[SQLBasedQueue] Updated record in {self._table_name}') + return result + except Exception as e: + LOG.error(f'[SQLBasedQueue] Failed to update {self._table_name}: {e}') + raise diff --git a/lazyllm/tools/rag/parsing_service/server.py b/lazyllm/tools/rag/parsing_service/server.py index f3df99970..f01ad5317 100644 --- a/lazyllm/tools/rag/parsing_service/server.py +++ b/lazyllm/tools/rag/parsing_service/server.py @@ -1,15 +1,19 @@ -import json import inspect +import json +import random import threading import time import traceback -from datetime import datetime -from typing import Any, Callable, Dict, Optional, List +from datetime import datetime, timedelta +from email.utils import parsedate_to_datetime +from typing import Any, Callable, Dict, List, Optional, Tuple +from uuid import NAMESPACE_URL, uuid5 from lazyllm import ( LOG, ModuleBase, ServerModule, UrlModule, FastapiApp as app, LazyLLMLaunchersBase as Launcher, load_obj, once_wrapper, dump_obj ) +import requests from lazyllm.thirdparty import fastapi from .base import ( @@ -27,18 +31,25 @@ from ..doc_to_db import SchemaExtractor from ...sql import SqlManager +CALLBACK_RETRY_MIN_INTERVAL = 5.0 +CALLBACK_RETRY_MAX_INTERVAL = 300.0 +CALLBACK_RETRY_MAX_ATTEMPTS = 5 + class DocumentProcessor(ModuleBase): - class _Impl(): + class _Impl: def __init__(self, db_config: Optional[Dict[str, Any]] = None, num_workers: int = 1, post_func: Optional[Callable] = None, path_prefix: Optional[str] = None, + callback_url: Optional[str] = None, 
lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: Optional[List[str]] = None, - high_priority_workers: int = 1): + high_priority_workers: int = 1, callback_task_statuses: Optional[List[str]] = None, + callback_task_types: Optional[List[str]] = None): self._db_config = db_config self._num_workers = num_workers self._post_func = post_func + self._callback_url = self._normalize_callback_url(callback_url) if not self._check_post_func(): raise ValueError('Invalid post function!') self._shutdown = False @@ -51,6 +62,9 @@ def __init__(self, db_config: Optional[Dict[str, Any]] = None, num_workers: int else [TaskType.DOC_DELETE.value] ) self._high_priority_workers = max(high_priority_workers, 0) + self._callback_task_statuses = callback_task_statuses + self._callback_task_types = callback_task_types + self._callback_retry_attempts: Dict[int, int] = {} self._db_manager = None self._waiting_task_queue = None @@ -75,6 +89,8 @@ def _lazy_init(self): table_name=FINISHED_TASK_QUEUE_TABLE_INFO['name'], columns=FINISHED_TASK_QUEUE_TABLE_INFO['columns'], db_config=self._db_config, + order_by='finished_at', + order_desc=False, ) self._post_func_thread = threading.Thread(target=self.process_finished_task, daemon=True) @@ -99,6 +115,8 @@ def _lazy_init(self): lease_renew_interval=self._lease_renew_interval, high_priority_task_types=high_priority_types, high_priority_only=True, + callback_task_statuses=self._callback_task_statuses, + callback_task_types=self._callback_task_types, ) self._high_priority_workers_module.start() if normal_workers > 0: @@ -108,6 +126,8 @@ def _lazy_init(self): lease_duration=self._lease_duration, lease_renew_interval=self._lease_renew_interval, high_priority_task_types=high_priority_types, + callback_task_statuses=self._callback_task_statuses, + callback_task_types=self._callback_task_types, ) self._workers.start() LOG.info('[DocumentProcessor] Lazy initialization completed!') @@ -129,15 +149,19 @@ def process_finished_task(self): '''process finished task in background thread''' while True: try: - finished_task = self._finished_task_queue.dequeue() + finished_task = self._finished_task_queue.peek() if finished_task: - self._callback( - task_id=finished_task.get('task_id'), - task_type=finished_task.get('task_type'), - task_status=finished_task.get('task_status'), - error_code=finished_task.get('error_code'), - error_msg=finished_task.get('error_msg') - ) + if not self._is_callback_due(finished_task): + time.sleep(0.5) + continue + try: + self._callback(finished_task) + except Exception as exc: + self._schedule_callback_retry(finished_task, exc) + time.sleep(0.1) + continue + self._finished_task_queue.clear(filter_by={'id': finished_task['id']}) + self._callback_retry_attempts.pop(finished_task['id'], None) time.sleep(0.1) else: time.sleep(1) @@ -145,6 +169,159 @@ def process_finished_task(self): LOG.error(f'[DocumentProcessor] Failed to process finished task: {e}, {traceback.format_exc()}') time.sleep(10) + @staticmethod + def _normalize_callback_url(callback_url: Optional[str]) -> Optional[str]: + if not callback_url: + return None + return callback_url.rstrip('/') + + def set_callback_url(self, callback_url: Optional[str]): + self._callback_url = self._normalize_callback_url(callback_url) + + @staticmethod + def _normalize_queue_datetime(value: Any) -> Optional[datetime]: + if value is None or isinstance(value, datetime): + return value + if isinstance(value, str): + try: + return datetime.fromisoformat(value) + except ValueError: + return 
None + return None + + def _is_callback_due(self, finished_task: Dict[str, Any]) -> bool: + finished_at = self._normalize_queue_datetime(finished_task.get('finished_at')) + return finished_at is None or finished_at <= datetime.now() + + @staticmethod + def _load_task_context(finished_task: Dict[str, Any]) -> Dict[str, Any]: + task_context_json = finished_task.get('task_context_json') + if not task_context_json: + raise ValueError('task_context_json is missing in finished task queue') + try: + task_context = json.loads(task_context_json) + except json.JSONDecodeError as exc: + raise ValueError(f'invalid task_context_json: {exc}') from exc + if not isinstance(task_context, dict): + raise ValueError('task_context_json must decode to dict') + return task_context + + def _drop_callback_task(self, finished_task: Dict[str, Any], exc: Exception, attempt: int, reason: str): + LOG.error('[DocumentProcessor] Callback delivery dropped queue item.' + f' queue_id={finished_task.get("id")}' + f' task_id={finished_task.get("task_id")}' + f' task_status={finished_task.get("task_status")}' + f' reason={reason}' + f' attempts={attempt}' + f' callback_url={finished_task.get("callback_url")}' + f' task_context_json={finished_task.get("task_context_json")}' + f' error={type(exc).__name__}: {exc}') + self._finished_task_queue.clear(filter_by={'id': finished_task['id']}) + self._callback_retry_attempts.pop(finished_task['id'], None) + + @staticmethod + def _parse_retry_after_seconds(value: Optional[str]) -> Optional[float]: + if not value: + return None + try: + return max(float(value), 0.0) + except (TypeError, ValueError): + pass + try: + retry_at = parsedate_to_datetime(value) + except (TypeError, ValueError, IndexError, OverflowError): + return None + if retry_at.tzinfo is not None: + delay = (retry_at - datetime.now(retry_at.tzinfo)).total_seconds() + else: + delay = (retry_at - datetime.now()).total_seconds() + return max(delay, 0.0) + + def _should_retry_callback_error(self, exc: Exception) -> Tuple[bool, Optional[float], str]: + if isinstance(exc, ValueError): + return False, None, 'invalid_callback_payload' + if isinstance(exc, (requests.Timeout, requests.ConnectionError)): + return True, None, 'transient_network_error' + if isinstance(exc, requests.HTTPError): + response = exc.response + if response is None: + return True, None, 'http_error_without_response' + status_code = response.status_code + if status_code in (408, 425): + return True, None, f'http_{status_code}' + if status_code == 429: + return True, self._parse_retry_after_seconds(response.headers.get('Retry-After')), 'http_429' + if status_code >= 500: + return True, None, f'http_{status_code}' + if 400 <= status_code < 500: + return False, None, f'http_{status_code}' + return True, None, type(exc).__name__ + + def _schedule_callback_retry(self, finished_task: Dict[str, Any], exc: Exception) -> bool: + queue_id = finished_task['id'] + attempt = self._callback_retry_attempts.get(queue_id, 0) + 1 + self._callback_retry_attempts[queue_id] = attempt + should_retry, retry_after_seconds, reason = self._should_retry_callback_error(exc) + if not should_retry: + self._drop_callback_task(finished_task, exc, attempt, reason) + return False + if attempt >= CALLBACK_RETRY_MAX_ATTEMPTS: + self._drop_callback_task(finished_task, exc, attempt, 'retry_exhausted') + return False + delay = CALLBACK_RETRY_MIN_INTERVAL * (2 ** (attempt - 1)) + delay *= random.uniform(0.8, 1.2) + if retry_after_seconds is not None: + delay = max(delay, retry_after_seconds) + delay = 
min(delay, CALLBACK_RETRY_MAX_INTERVAL) + retry_at = datetime.now() + timedelta(seconds=delay) + self._finished_task_queue.update(filter_by={'id': queue_id}, finished_at=retry_at) + LOG.warning(f'[DocumentProcessor] Callback delivery failed for queue_id={queue_id},' + f' task_id={finished_task.get("task_id")}, retry in {delay:.1f}s' + f' (attempt={attempt}/{CALLBACK_RETRY_MAX_ATTEMPTS})' + f' reason={reason}' + f' error={type(exc).__name__}: {exc}') + return True + + def _resolve_callback_url(self, payload: Dict[str, Any]) -> Optional[str]: + return self._normalize_callback_url( + payload.get('callback_url') or payload.get('feedback_url') or self._callback_url + ) + + def _default_post_func(self, finished_task: Dict[str, Any]): + task_id = finished_task.get('task_id') + task_status = finished_task.get('task_status') + error_code = finished_task.get('error_code') + error_msg = finished_task.get('error_msg') + callback_url = self._normalize_callback_url(finished_task.get('callback_url')) + if not callback_url: + raise ValueError(f'callback_url is missing for task {task_id}') + task_context = self._load_task_context(finished_task) + + base_payload = {'task_type': task_context.get('task_type'), + 'kb_id': task_context.get('kb_id'), + 'algo_id': task_context.get('algo_id')} + items = task_context.get('items') or [{}] + for index, item in enumerate(items): + callback_payload = { + 'callback_id': str(uuid5(NAMESPACE_URL, f'{task_id}:{task_status}:{index}')), + 'task_id': task_id, + 'task_status': task_status, + 'payload': {k: v for k, v in {**base_payload, **item}.items() if v is not None}, + } + for field in ('task_type', 'kb_id', 'algo_id'): + if base_payload.get(field) is not None: + callback_payload[field] = base_payload[field] + if item.get('doc_id') is not None: + callback_payload['doc_id'] = item['doc_id'] + if error_code is not None: + callback_payload['error_code'] = error_code + if error_msg is not None: + callback_payload['error_msg'] = error_msg + + response = requests.post(callback_url, json=callback_payload, timeout=8) + response.raise_for_status() + return True + def register_algorithm(self, name: str, store: _DocumentStore, reader: DirectoryReader, node_groups: Dict[str, Dict], schema_extractor: Optional[SchemaExtractor] = None, display_name: Optional[str] = None, description: Optional[str] = None): @@ -266,6 +443,8 @@ def get_algo_list(self) -> None: 'algo_id': algo_dict.get('id'), 'display_name': algo_dict.get('display_name'), 'description': algo_dict.get('description'), + 'created_at': algo_dict.get('created_at'), + 'updated_at': algo_dict.get('updated_at'), }) if not data: LOG.warning('[DocumentProcessor] No algorithm registered') @@ -286,11 +465,10 @@ def get_algo_group_info(self, algo_id: str) -> None: LOG.error(f'[DocumentProcessor] Failed to get group info: {e}, {traceback.format_exc()}') raise fastapi.HTTPException(status_code=500, detail=f'Failed to get group info: {str(e)}') - @app.get('/group/info') def get_group_info(self, algo_id: str) -> None: return self.get_algo_group_info(algo_id) - def _get_algo_group_info_data(self, algo_id: str) -> List[Dict[str, Any]]: + def _get_algo_group_info_data(self, algo_id: str): algorithm = self._get_algo(algo_id) if algorithm is None: raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') @@ -307,6 +485,88 @@ def _get_algo_group_info_data(self, algo_id: str) -> List[Dict[str, Any]]: data.append(group_info) return data + @staticmethod + def _format_chunk_item(segment: Dict[str, Any]) -> Dict[str, Any]: + 
return { + 'uid': segment.get('uid'), + 'doc_id': segment.get('doc_id'), + 'kb_id': segment.get('kb_id'), + 'group': segment.get('group'), + 'number': segment.get('number', 0), + 'content': segment.get('content'), + 'type': segment.get('type'), + 'parent': segment.get('parent'), + 'metadata': segment.get('meta', {}), + 'global_metadata': segment.get('global_meta', {}), + 'answer': segment.get('answer', ''), + 'image_keys': segment.get('image_keys', []), + } + + def _list_doc_chunks_data( + self, algo_id: str, kb_id: str, doc_id: str, group: str, offset: int = 0, limit: int = 20 + ) -> Dict[str, Any]: + algorithm = self._get_algo(algo_id) + if algorithm is None: + raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') + info = load_obj(algorithm.get('info_pickle')) + store: _DocumentStore = info['store'] # type: ignore + node_groups = info.get('node_groups', {}) + if group not in node_groups or not store.is_group_active(group): + raise fastapi.HTTPException(status_code=400, detail=f'Invalid group {group}') + offset = max(offset, 0) + limit = max(limit, 1) + segments, total = store.get_segments( + doc_ids={doc_id}, + kb_id=kb_id, + group=group, + offset=offset, + limit=limit, + return_total=True, + sort_by_number=True, + ) + return { + 'items': [self._format_chunk_item(segment) for segment in segments], + 'total': total, + 'offset': offset, + 'page_size': limit, + } + + @app.get('/doc/chunks') + def list_doc_chunks( + self, + algo_id: str = '__default__', + kb_id: Optional[str] = None, + doc_id: Optional[str] = None, + group: Optional[str] = None, + offset: int = 0, + page_size: int = 20, + ): + self._lazy_init() + if self._shutdown: + raise fastapi.HTTPException(status_code=503, detail='Server is shutting down...') + if not kb_id: + raise fastapi.HTTPException(status_code=400, detail='kb_id is required') + if not doc_id: + raise fastapi.HTTPException(status_code=400, detail='doc_id is required') + if not group: + raise fastapi.HTTPException(status_code=400, detail='group is required') + data = self._list_doc_chunks_data( + algo_id=algo_id, + kb_id=kb_id, + doc_id=doc_id, + group=group, + offset=offset, + limit=page_size, + ) + return BaseResponse(code=200, msg='success', data=data) + + @staticmethod + def _resolve_add_task_type(request: AddDocRequest) -> str: + try: + return _resolve_add_doc_task_type(request) + except ValueError as exc: + raise fastapi.HTTPException(status_code=400, detail=str(exc)) from exc + @app.post('/doc/add') def add_doc(self, request: AddDocRequest): # noqa: C901 self._lazy_init() @@ -323,16 +583,13 @@ def add_doc(self, request: AddDocRequest): # noqa: C901 raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') # NOTE: No idempotency key check, should be handled by the caller! 
for file_info in file_infos: - parse_file_path = file_info.transformed_file_path if \ - file_info.transformed_file_path else file_info.file_path - file_info.file_path = create_file_path(parse_file_path, prefix=self._path_prefix) - - try: - task_type = _resolve_add_doc_task_type(request) - except ValueError as e: - raise fastapi.HTTPException(status_code=400, detail=str(e)) + if self._path_prefix: + file_info.file_path = create_file_path(path=file_info.file_path, prefix=self._path_prefix) + task_type = self._resolve_add_task_type(request) payload = request.model_dump() - LOG.info(f'[DocumentProcessor] Received add doc request: {payload}') + resolved_callback_url = self._resolve_callback_url(payload) + if resolved_callback_url: + payload['callback_url'] = resolved_callback_url payload_json = json.dumps(payload, ensure_ascii=False) try: @@ -380,6 +637,9 @@ def update_meta(self, request: UpdateMetaRequest): raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') payload = request.model_dump() LOG.info(f'[DocumentProcessor] Received update meta request: {payload}') + resolved_callback_url = self._resolve_callback_url(payload) + if resolved_callback_url: + payload['callback_url'] = resolved_callback_url payload_json = json.dumps(payload, ensure_ascii=False) try: task_type = TaskType.DOC_UPDATE_META.value @@ -426,6 +686,9 @@ def delete_doc(self, request: DeleteDocRequest): raise fastapi.HTTPException(status_code=404, detail=f'Invalid algo_id {algo_id}') payload = request.model_dump() LOG.info(f'[DocumentProcessor] Received delete doc request: {payload}') + resolved_callback_url = self._resolve_callback_url(payload) + if resolved_callback_url: + payload['callback_url'] = resolved_callback_url payload_json = json.dumps(payload, ensure_ascii=False) try: task_type = TaskType.DOC_DELETE.value @@ -492,8 +755,12 @@ def cancel(self, request: CancelTaskRequest): def _check_post_func(self) -> bool: '''assert post function is callable and params include task_id, task_status, error_code, error_msg''' if not self._post_func: - LOG.warning('[DocumentProcessor] No post function configured,' - ' task status callback will not be performed!') + if self._callback_url: + LOG.info('[DocumentProcessor] No custom post function configured,' + ' using built-in HTTP callback') + else: + LOG.warning('[DocumentProcessor] No custom post function configured, built-in HTTP callback' + ' will only run when callback_url or feedback_url is provided in task request') return True if not callable(self._post_func): LOG.error('[DocumentProcessor] Post function is not callable') @@ -513,51 +780,36 @@ def _check_post_func(self) -> bool: return False return True - def _callback(self, task_id: str, task_type: str = None, - task_status: str = None, error_code: str = None, error_msg: str = None): + def _callback(self, finished_task: Optional[Dict[str, Any]] = None, **legacy_kwargs): '''callback to service''' - parts = [f'task_id={task_id}'] - if task_type: - parts.append(f'task_type={task_type}') - if task_status: - parts.append(f'status={task_status}') - has_error = ( - task_status not in (TaskStatus.FINISHED.value, None) - or error_code not in (None, '', '200') - or error_msg not in (None, '', 'success') - ) - if has_error: - if error_code not in (None, ''): - parts.append(f'error_code={error_code}') - if error_msg not in (None, ''): - parts.append(f'error_msg={error_msg}') - message = '[DocumentProcessor] Task callback: ' + ', '.join(parts) - if task_status == TaskStatus.FAILED.value: - LOG.error(message) - elif 
task_status in (TaskStatus.CANCELED.value, TaskStatus.CANCEL_REQUESTED.value): - LOG.warning(message) - else: - LOG.info(message) + if finished_task is None: + finished_task = legacy_kwargs + task_id = finished_task.get('task_id') + task_type = finished_task.get('task_type') + task_status = finished_task.get('task_status') + error_code = finished_task.get('error_code') + error_msg = finished_task.get('error_msg') + message = f'Task {task_id} callback status: {task_status}.' + if error_msg: + message += f' Error code: {error_code}, error_msg: {error_msg}.' + LOG.info(f'[DocumentProcessor] {message}') - if self._post_func: - try: - params = inspect.signature(self._post_func).parameters - accepts_task_type = ( - 'task_type' in params - or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()) - ) - if accepts_task_type: - self._post_func( - task_id=task_id, - task_type=task_type, - task_status=task_status, - error_code=error_code, - error_msg=error_msg, - ) + try: + if self._post_func: + if 'task_type' in self._post_func.__code__.co_varnames: + self._post_func(task_id, task_status, error_code, error_msg, task_type=task_type) else: self._post_func(task_id, task_status, error_code, error_msg) - except Exception as e: - LOG.error(f'[DocumentProcessor] Failed to call post function: {e}, {traceback.format_exc()}') + else: + self._default_post_func(finished_task) + except Exception as e: + LOG.error(f'[DocumentProcessor] Failed to call post function: {e}, {traceback.format_exc()}') + if self._post_func: + try: + self._default_post_func(finished_task) + except Exception: + raise e + else: raise e def __call__(self, func_name: str, *args, **kwargs): @@ -566,9 +818,11 @@ def __call__(self, func_name: str, *args, **kwargs): def __init__(self, port: int = None, url: str = None, num_workers: int = 1, db_config: Optional[Dict[str, Any]] = None, launcher: Optional[Launcher] = None, post_func: Optional[Callable] = None, - path_prefix: Optional[str] = None, lease_duration: float = 300.0, + path_prefix: Optional[str] = None, callback_url: Optional[str] = None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: Optional[List[str]] = None, - high_priority_workers: int = 1): + high_priority_workers: int = 1, pythonpath: Optional[str] = None, + callback_task_statuses: Optional[List[str]] = None, + callback_task_types: Optional[List[str]] = None): super().__init__() self._raw_impl = None # save the reference of the original Impl object self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') @@ -583,8 +837,11 @@ def __init__(self, port: int = None, url: str = None, num_workers: int = 1, lease_renew_interval=lease_renew_interval, high_priority_task_types=high_priority_task_types, high_priority_workers=high_priority_workers, + callback_url=callback_url, + callback_task_statuses=callback_task_statuses, + callback_task_types=callback_task_types, ) - self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher) + self._impl = ServerModule(self._raw_impl, port=port, launcher=launcher, pythonpath=pythonpath) else: self._impl = UrlModule(url=ensure_call_endpoint(url)) @@ -602,6 +859,22 @@ def start(self): raise return result + def wait(self): + impl = self._impl + if isinstance(impl, ServerModule): + return impl.wait() + LOG.warning('[DocumentProcessor] wait() is no-op in UrlModule mode') + + def set_callback_url(self, callback_url: Optional[str]): + if isinstance(self._impl, UrlModule): + raise 
RuntimeError('set_callback_url is only supported in local server mode') + return self._dispatch('set_callback_url', callback_url) + + @property + def url(self): + impl = self._impl + return impl._url if isinstance(impl, ServerModule) else impl.url + def _dispatch(self, method: str, *args, **kwargs): try: impl = self._impl diff --git a/lazyllm/tools/rag/parsing_service/worker.py b/lazyllm/tools/rag/parsing_service/worker.py index 2404cc0ec..9216e551e 100644 --- a/lazyllm/tools/rag/parsing_service/worker.py +++ b/lazyllm/tools/rag/parsing_service/worker.py @@ -26,7 +26,8 @@ class DocumentProcessorWorker(ModuleBase): class _Impl(): def __init__(self, db_config: dict = None, task_poller=None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: list[str] = None, - high_priority_only: bool = False, poll_mode: str = 'direct'): + high_priority_only: bool = False, poll_mode: str = 'thread', + callback_task_statuses: list[str] = None, callback_task_types: list[str] = None): self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') self._shutdown = False self._processors: dict[str, _Processor] = {} # algo_id -> _Processor @@ -49,6 +50,10 @@ def __init__(self, db_config: dict = None, task_poller=None, lease_duration: flo self._lease_renew_interval = lease_renew_interval self._high_priority_task_types = set(high_priority_task_types or []) self._high_priority_only = high_priority_only + self._callback_task_statuses = {TaskStatus(status).value for status in callback_task_statuses} \ + if callback_task_statuses else None + self._callback_task_types = {TaskType(task_type).value for task_type in callback_task_types} \ + if callback_task_types else None @once_wrapper(reset_on_pickle=True) def _lazy_init(self): @@ -126,6 +131,8 @@ def _fail_in_progress_task(self, reason: str): return task_id = self._in_progress_task.get('task_id') task_type = self._in_progress_task.get('task_type') + callback_url = self._in_progress_task.get('callback_url') + task_context_json = self._in_progress_task.get('task_context_json') if task_id and task_type: self._enqueue_finished_task( task_id=task_id, @@ -133,6 +140,8 @@ def _fail_in_progress_task(self, reason: str): task_status=TaskStatus.FAILED, error_code='PRESTOP', error_msg=reason, + callback_url=callback_url, + task_context_json=task_context_json, ) deleted = self._waiting_task_queue.delete( filter_by={'task_id': task_id, 'worker_id': self._worker_id} @@ -250,6 +259,7 @@ def _exec_reparse_task( def _exec_transfer_task(self, processor: _Processor, task_id: str, payload: dict): try: + self._validate_transfer_payload(payload) file_infos = payload.get('file_infos') kb_id = payload.get('kb_id', None) input_files = [] @@ -297,9 +307,86 @@ def _exec_update_meta_task(self, processor: _Processor, task_id: str, payload: d LOG.error(f'{self._log_prefix(task_id)} Execute update meta task failed: {e}') raise e + @staticmethod + def _resolve_callback_url(payload: dict): + return payload.get('callback_url') or payload.get('feedback_url') + + @staticmethod + def _build_task_context(task_type: str, payload: dict) -> dict: + items = [] + if task_type in (TaskType.DOC_ADD.value, TaskType.DOC_REPARSE.value, TaskType.DOC_UPDATE_META.value): + file_infos = payload.get('file_infos') or [] + items = [{ + 'doc_id': file_info.get('doc_id'), + 'file_path': file_info.get('file_path'), + 'metadata': file_info.get('metadata'), + 'reparse_group': file_info.get('reparse_group'), + } for file_info in file_infos] + elif task_type == 
TaskType.DOC_DELETE.value: + items = [{'doc_id': doc_id} for doc_id in (payload.get('doc_ids') or [])] + elif task_type == TaskType.DOC_TRANSFER.value: + file_infos = payload.get('file_infos') or [] + items = [{ + 'doc_id': file_info.get('doc_id'), + 'file_path': file_info.get('file_path'), + 'metadata': file_info.get('metadata'), + 'target_doc_id': (file_info.get('transfer_params') or {}).get('target_doc_id'), + 'target_kb_id': (file_info.get('transfer_params') or {}).get('target_kb_id'), + 'transfer_mode': (file_info.get('transfer_params') or {}).get('mode'), + } for file_info in file_infos] + if not items: + items = [{}] + return { + 'task_type': task_type, + 'kb_id': payload.get('kb_id'), + 'algo_id': payload.get('algo_id'), + 'items': items, + } + def _resolve_task_type(self, request: AddDocRequest) -> str: return _resolve_add_doc_task_type(request) + @staticmethod + def _validate_transfer_payload(payload: dict): # noqa: C901 + file_infos = payload.get('file_infos') + if not isinstance(file_infos, list) or not file_infos: + raise ValueError('file_infos is required for task_type DOC_TRANSFER') + transfer_mode = None + target_kb_id = None + target_algo_id = None + target_doc_ids = set() + request_algo_id = payload.get('algo_id') + for idx, file_info in enumerate(file_infos): + transfer_params = file_info.get('transfer_params') + if not isinstance(transfer_params, dict) or not transfer_params: + raise ValueError(f'transfer_params is required for task_type DOC_TRANSFER at index {idx}') + current_mode = transfer_params.get('mode') + current_target_kb_id = transfer_params.get('target_kb_id') + current_target_algo_id = transfer_params.get('target_algo_id') + current_target_doc_id = transfer_params.get('target_doc_id') + if current_mode not in ('cp', 'mv'): + raise ValueError('transfer_params.mode must be one of [cp, mv]') + if not current_target_kb_id: + raise ValueError('transfer_params.target_kb_id is required for task_type DOC_TRANSFER') + if not current_target_algo_id: + raise ValueError('transfer_params.target_algo_id is required for task_type DOC_TRANSFER') + if not current_target_doc_id: + raise ValueError('transfer_params.target_doc_id is required for task_type DOC_TRANSFER') + if transfer_mode is not None and transfer_mode != current_mode: + raise ValueError('transfer_params.mode must be the same for all files') + if target_kb_id is not None and target_kb_id != current_target_kb_id: + raise ValueError('transfer_params.target_kb_id must be the same for all files') + if target_algo_id is not None and target_algo_id != current_target_algo_id: + raise ValueError('transfer_params.target_algo_id must be the same for all files') + if request_algo_id is not None and current_target_algo_id != request_algo_id: + raise ValueError('transfer_params.target_algo_id must match request.algo_id') + if current_target_doc_id in target_doc_ids: + raise ValueError('transfer_params.target_doc_id must be unique for all files') + transfer_mode = current_mode + target_kb_id = current_target_kb_id + target_algo_id = current_target_algo_id + target_doc_ids.add(current_target_doc_id) + def _validate_task_payload(self, task_type: str, payload: dict): if not isinstance(payload, dict): raise ValueError('payload must be a dict') @@ -316,6 +403,8 @@ def _validate_task_payload(self, task_type: str, payload: dict): doc_ids = payload.get('doc_ids') if not isinstance(doc_ids, list) or not doc_ids: raise ValueError('doc_ids is required for task_type DOC_DELETE') + if task_type == TaskType.DOC_TRANSFER.value: + 
self._validate_transfer_payload(payload) def _summarize_task_payload(self, task_type: str, payload: dict) -> str: summary = { @@ -380,20 +469,27 @@ def _enqueue_task_from_payload(self, task: dict): def _parse_task_payload(self, task: dict): task_type = task.get('task_type') if task_type == TaskType.DOC_DELETE.value: - task_info = DeleteDocRequest(**task) + task_info = DeleteDocRequest.model_validate(task) elif task_type == TaskType.DOC_UPDATE_META.value: - task_info = UpdateMetaRequest(**task) + task_info = UpdateMetaRequest.model_validate(task) else: - task_info = AddDocRequest(**task) + task_info = AddDocRequest.model_validate(task) task_type = task_type or self._resolve_task_type(task_info) task_id = task_info.task_id - payload = task_info.model_dump() + payload = task_info.model_dump(mode='json') self._validate_task_payload(task_type, payload) return task_id, task_type, payload def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: bool): # noqa: C901 + callback_url = self._resolve_callback_url(payload) + task_context_json = json.dumps(self._build_task_context(task_type, payload), ensure_ascii=False) try: - self._in_progress_task = {'task_id': task_id, 'task_type': task_type} + self._in_progress_task = { + 'task_id': task_id, + 'task_type': task_type, + 'callback_url': callback_url, + 'task_context_json': task_context_json, + } if from_queue: self._start_lease_renewal(task_id) algo_id = payload.get('algo_id') @@ -402,6 +498,13 @@ def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: boo LOG.info(f'{self._log_prefix(task_id)} Start processing task: ' f'{self._summarize_task_payload(task_type, payload)}') + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.WORKING, + callback_url=callback_url, + task_context_json=task_context_json, + ) processor = self._get_or_create_processor(algo_id) if task_type == TaskType.DOC_ADD.value: @@ -417,8 +520,15 @@ def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: boo else: raise ValueError(f'{self._log_prefix(task_id)} Unknown task type: {task_type}') - self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FINISHED, - error_code='200', error_msg='success') + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.SUCCESS, + error_code='200', + error_msg='success', + callback_url=callback_url, + task_context_json=task_context_json, + ) if from_queue: deleted = self._waiting_task_queue.delete( filter_by={'task_id': task_id, 'worker_id': self._worker_id} @@ -428,8 +538,15 @@ def _run_task(self, task_id: str, task_type: str, payload: dict, from_queue: boo except Exception as e: LOG.error(f'{self._log_prefix(task_id)} Failed to run task: {e}, {traceback.format_exc()}') if task_id and task_type: - self._enqueue_finished_task(task_id=task_id, task_type=task_type, task_status=TaskStatus.FAILED, - error_code=type(e).__name__, error_msg=str(e)) + self._enqueue_finished_task( + task_id=task_id, + task_type=task_type, + task_status=TaskStatus.FAILED, + error_code=type(e).__name__, + error_msg=str(e), + callback_url=callback_url, + task_context_json=task_context_json, + ) if from_queue: deleted = self._waiting_task_queue.delete( filter_by={'task_id': task_id, 'worker_id': self._worker_id} @@ -471,19 +588,28 @@ def _poll_task(self): ) def _enqueue_finished_task(self, task_id: str, task_type: str, task_status: TaskStatus, - error_code: str = None, error_msg: str = None): + 
error_code: str = None, error_msg: str = None, + callback_url: str = None, task_context_json: str = None): try: self._lazy_init() + if self._callback_task_statuses and task_status.value not in self._callback_task_statuses: + return + if self._callback_task_types and task_type not in self._callback_task_types: + return self._finished_task_queue.enqueue( task_id=task_id, task_type=task_type, task_status=task_status.value, finished_at=datetime.now(), + callback_url=callback_url, + task_context_json=task_context_json, error_code=error_code if error_code else '200', error_msg=error_msg if error_msg else 'success' ) - if task_status == TaskStatus.FINISHED: - LOG.info(f'{self._log_prefix(task_id)} Task finished successfully') + if task_status == TaskStatus.WORKING: + LOG.info(f'{self._log_prefix(task_id)} Task started') + elif task_status == TaskStatus.SUCCESS: + LOG.info(f'{self._log_prefix(task_id)} Task completed successfully') else: LOG.error(f'{self._log_prefix(task_id)} Task completed with status {task_status}: {error_msg}') except Exception as e: @@ -498,8 +624,18 @@ def _worker_impl(self): # noqa: C901 time.sleep(WORKER_ERROR_RETRY_INTERVAL) continue if task_data: + payload = None + callback_url = None + task_context_json = None try: payload = json.loads(task_data.get('message')) + callback_url = self._resolve_callback_url(payload) if isinstance(payload, dict) else None + if isinstance(payload, dict): + task_context_json = json.dumps( + self._build_task_context(task_data['task_type'], payload), + ensure_ascii=False, + ) + self._validate_task_payload(task_data['task_type'], payload) except Exception as e: task_id = task_data.get('task_id') task_type = task_data.get('task_type') @@ -513,6 +649,8 @@ def _worker_impl(self): # noqa: C901 task_status=TaskStatus.FAILED, error_code=type(e).__name__, error_msg=str(e), + callback_url=callback_url, + task_context_json=task_context_json, ) deleted = self._waiting_task_queue.delete( filter_by={'task_id': task_id, 'worker_id': self._worker_id} @@ -587,7 +725,8 @@ def shutdown(self): def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = None, task_poller=None, lease_duration: float = 300.0, lease_renew_interval: float = 60.0, high_priority_task_types: list[str] = None, high_priority_only: bool = False, - poll_mode: str = 'direct'): + poll_mode: str = 'thread', callback_task_statuses: list[str] = None, + callback_task_types: list[str] = None): super().__init__() self._db_config = db_config if db_config else _get_default_db_config('doc_task_management') self._num_workers = num_workers @@ -600,6 +739,8 @@ def __init__(self, db_config: dict = None, num_workers: int = 1, port: int = Non high_priority_task_types=high_priority_task_types, high_priority_only=high_priority_only, poll_mode=poll_mode, + callback_task_statuses=callback_task_statuses, + callback_task_types=callback_task_types, ) self._worker_impl = ServerModule(worker_impl, port=self._port, num_replicas=self._num_workers) LOG.info(f'[DocumentProcessorWorker] Worker initialized with {num_workers} workers') @@ -618,6 +759,12 @@ def start(self): LOG.info('[DocumentProcessorWorker] Worker started') return result + def wait(self): + impl = self._worker_impl + if isinstance(impl, ServerModule): + return impl.wait() + LOG.warning('[DocumentProcessorWorker] wait() is no-op in local mode') + def stop(self): LOG.info('[DocumentProcessorWorker] Stopping worker...') self._dispatch('shutdown') diff --git a/lazyllm/tools/rag/store/document_store.py 
b/lazyllm/tools/rag/store/document_store.py index e1e9b4894..b84b89102 100644 --- a/lazyllm/tools/rag/store/document_store.py +++ b/lazyllm/tools/rag/store/document_store.py @@ -1,4 +1,6 @@ +import hashlib import os +import re import traceback import lazyllm from collections import defaultdict @@ -20,6 +22,9 @@ class _DocumentStore(object): + _COLLECTION_NAME_PATTERN = re.compile(r'[^a-z0-9_]+') + _COLLECTION_NAME_MAX_LEN = 255 + def __init__(self, algo_name: str, store: Union[Dict, LazyLLMStoreBase], group_embed_keys: Optional[Dict[str, Set[str]]] = None, embed: Optional[Dict[str, Callable]] = None, embed_dims: Optional[Dict[str, int]] = None, embed_datatypes: Optional[Dict[str, DataType]] = None, @@ -153,6 +158,11 @@ def is_group_active(self, group: str) -> bool: def is_group_empty(self, group: str) -> bool: return not self.impl.get(self._gen_collection_name(group), {}, limit=10) + def _upsert_segments(self, group: str, segments: List[dict]) -> None: + collection_name = self._gen_collection_name(group) + if not self.impl.upsert(collection_name, segments): + raise RuntimeError(f'[_DocumentStore - {self._algo_name}] Failed to upsert segments for group {group}') + def update_nodes(self, nodes: List[DocNode], copy: bool = False): # noqa: C901 if not nodes: return @@ -171,7 +181,7 @@ def update_nodes(self, nodes: List[DocNode], copy: bool = False): # noqa: C901 LOG.warning(f'[_DocumentStore - {self._algo_name}] Group {group} is not active, skip') continue for i in range(0, len(segments), INSERT_BATCH_SIZE): - self.impl.upsert(self._gen_collection_name(group), segments[i:i + INSERT_BATCH_SIZE]) + self._upsert_segments(group, segments[i:i + INSERT_BATCH_SIZE]) # update indices for index in self._indices.values(): index.update(nodes) @@ -213,11 +223,13 @@ def remove_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, group: Optional[str] = None, kb_id: Optional[str] = None, limit: Optional[int] = None, offset: int = 0, return_total: bool = False, - numbers: Optional[Set] = None, **kwargs) -> Union[List[DocNode], Tuple[List[DocNode], int]]: + numbers: Optional[Set] = None, sort_by_number: bool = False, + **kwargs) -> Union[List[DocNode], Tuple[List[DocNode], int]]: try: result = self.get_segments(uids=uids, doc_ids=doc_ids, group=group, kb_id=kb_id, numbers=numbers, limit=limit, - offset=offset, return_total=return_total, **kwargs) + offset=offset, return_total=return_total, + sort_by_number=sort_by_number, **kwargs) if return_total: segments, total = result return [self._deserialize_node(segment) for segment in segments], total @@ -229,7 +241,8 @@ def get_nodes(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = N def get_segments(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] = None, group: Optional[str] = None, kb_id: Optional[str] = None, limit: Optional[int] = None, offset: int = 0, return_total: bool = False, - numbers: Optional[Set] = None, **kwargs) -> Union[List[dict], Tuple[List[dict], int]]: + numbers: Optional[Set] = None, sort_by_number: bool = False, + **kwargs) -> Union[List[dict], Tuple[List[dict], int]]: # get a set of segments by uids # get the segments of the whole file -- doc ids only # get the segments of a certain group for one file -- doc ids and group (kb_id is optional) @@ -240,12 +253,28 @@ def get_segments(self, uids: Optional[List[str]] = None, doc_ids: Optional[Set] limit, offset = self._normalize_pagination(limit, offset) criteria = 
self._build_get_criteria(uids, doc_ids, kb_id, numbers, kwargs.get('parent')) groups = self._resolve_groups(group) + if self._can_use_native_segment_pagination(groups, sort_by_number): + result = self.impl.get( + self._gen_collection_name(groups[0]), + criteria, + limit=limit, + offset=offset, + return_total=return_total, + sort_by_number=sort_by_number, + **kwargs, + ) + if return_total: + segments, total = result if isinstance(result, tuple) else (result, len(result)) + return segments, total + return result[0] if isinstance(result, tuple) else result segments = [] for group in groups: if not self.is_group_active(group): LOG.warning(f'[_DocumentStore - {self._algo_name}] Group {group} is not active, skip') continue segments.extend(self.impl.get(self._gen_collection_name(group), criteria, **kwargs)) + if sort_by_number: + segments = self._sort_segments_by_number(segments) total = len(segments) segments = self._slice_segments(segments, limit, offset) return (segments, total) if return_total else segments @@ -264,7 +293,7 @@ def update_doc_meta(self, doc_id: str, metadata: dict, kb_id: str = None) -> Non segment['global_meta'].update(metadata) group_segments[segment.get('group')].append(segment) for group, segments in group_segments.items(): - self.impl.upsert(self._gen_collection_name(group), segments) + self._upsert_segments(group, segments) LOG.info(f'[_DocumentStore] Updated metadata for doc_id: {doc_id} in dataset: {kb_id}') return @@ -283,11 +312,21 @@ def _slice_segments(segments: List[dict], limit: Optional[int], offset: int) -> return segments[offset:end] return segments + @staticmethod + def _sort_segments_by_number(segments: List[dict]) -> List[dict]: + return sorted(segments, key=lambda segment: (segment.get('number', 0), segment.get('uid', ''))) + def _resolve_groups(self, group: Optional[str]) -> List[str]: if not group: return sorted(self._activated_groups) return [group] + def _can_use_native_segment_pagination(self, groups: List[str], sort_by_number: bool) -> bool: + if len(groups) != 1 or not sort_by_number: + return False + store = getattr(self.impl, 'segment_store', self.impl) + return store.__class__.__name__ in {'MapStore', 'OpenSearchStore', 'ElasticSearchStore'} + def _build_get_criteria(self, uids: Optional[List[str]], doc_ids: Optional[Set], kb_id: Optional[str], numbers: Optional[Set] = None, parent: Optional[Union[str, List[str]]] = None) -> Dict[str, Any]: @@ -440,4 +479,16 @@ def _deserialize_node(self, data: dict, score: Optional[float] = None) -> DocNod return node.with_sim_score(score) if score else node def _gen_collection_name(self, group: str) -> str: - return f'col_{self._algo_name}_{group}'.lower() + raw_name = f'col_{self._algo_name}_{group}'.lower() + normalized = self._COLLECTION_NAME_PATTERN.sub('_', raw_name).strip('_') + if not normalized: + normalized = 'col' + if normalized[0].isdigit(): + normalized = f'col_{normalized}' + if normalized == raw_name and len(normalized) <= self._COLLECTION_NAME_MAX_LEN: + return normalized + + digest = hashlib.sha1(raw_name.encode()).hexdigest()[:12] + max_prefix_len = self._COLLECTION_NAME_MAX_LEN - len(digest) - 1 + prefix = normalized[:max_prefix_len].rstrip('_') or 'col' + return f'{prefix}_{digest}' diff --git a/lazyllm/tools/rag/store/hybrid/hybrid_store.py b/lazyllm/tools/rag/store/hybrid/hybrid_store.py index 5b3c1b788..b9ecc7902 100644 --- a/lazyllm/tools/rag/store/hybrid/hybrid_store.py +++ b/lazyllm/tools/rag/store/hybrid/hybrid_store.py @@ -37,9 +37,13 @@ def delete(self, collection_name: str, 
criteria: Optional[dict] = None, **kwargs @override def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: res_segments = self.segment_store.get(collection_name=collection_name, criteria=criteria, **kwargs) - if not res_segments: return [] + total = None + if isinstance(res_segments, tuple): + res_segments, total = res_segments + if not res_segments: + return ([], total or 0) if total is not None else [] uids = [item.get('uid') for item in res_segments] - res_vectors = self.vector_store.get(collection_name=collection_name, criteria={'uid': uids}, **kwargs) + res_vectors = self.vector_store.get(collection_name=collection_name, criteria={'uid': uids}) data = {} for item in res_segments: @@ -50,7 +54,8 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) - else: raise ValueError(f'[HybridStore - get] uid {item["uid"]} in vector store' ' but not found in segment store') - return list(data.values()) + ordered = [data[item.get('uid')] for item in res_segments if item.get('uid') in data] + return (ordered, total) if total is not None else ordered @override def search(self, collection_name: str, query: str, query_embedding: Optional[Union[dict, List[float]]] = None, diff --git a/lazyllm/tools/rag/store/hybrid/map_store.py b/lazyllm/tools/rag/store/hybrid/map_store.py index 6dbe98162..3a7d20ab9 100644 --- a/lazyllm/tools/rag/store/hybrid/map_store.py +++ b/lazyllm/tools/rag/store/hybrid/map_store.py @@ -179,24 +179,48 @@ def _remove_uid(uid: str, use_discard: bool) -> None: @override def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: + limit = kwargs.get('limit') + offset = max(kwargs.get('offset', 0) or 0, 0) + return_total = kwargs.get('return_total', False) + sort_by_number = kwargs.get('sort_by_number', False) if self._sqlite_first: with self._lock: conn = self._open_conn() cur = conn.cursor() self._ensure_table(cur, collection_name) where, args = self._build_where(criteria) + order_clause = ' ORDER BY number ASC, uid ASC' if sort_by_number else '' + page_clause = '' + page_args = () + if limit is not None: + page_clause = ' LIMIT ? OFFSET ?' + page_args = (limit, offset) + elif offset > 0: + page_clause = ' LIMIT -1 OFFSET ?' 
+ page_args = (offset,) + total = None + if return_total: + cur.execute(f'SELECT COUNT(*) FROM {collection_name}{where}', args) + total = cur.fetchone()[0] cur.execute(f'''SELECT uid, doc_id, "group", content, meta, global_meta, type, number, kb_id, excluded_embed_metadata_keys, excluded_llm_metadata_keys, parent, answer, image_keys - FROM {collection_name}{where}''', args) + FROM {collection_name}{where}{order_clause}{page_clause}''', args + page_args) rows = cur.fetchall() res = [] for r in rows: item = self._deserialize_data(r) res.append(item) - return res + return (res, total) if return_total else res else: uids = self._get_uids_by_criteria(collection_name, criteria) - return [self._uid2data[uid] for uid in uids if uid in self._uid2data] + res = [self._uid2data[uid] for uid in uids if uid in self._uid2data] + if sort_by_number: + res = sorted(res, key=lambda item: (item.get('number', 0), item.get('uid', ''))) + total = len(res) + if offset > 0 or limit is not None: + end = None if limit is None else offset + limit + res = res[offset:end] + return (res, total) if return_total else res def _build_where(self, criteria: dict): if not criteria: diff --git a/lazyllm/tools/rag/store/segment/elasticsearch_store.py b/lazyllm/tools/rag/store/segment/elasticsearch_store.py index e2678645e..817fcc770 100644 --- a/lazyllm/tools/rag/store/segment/elasticsearch_store.py +++ b/lazyllm/tools/rag/store/segment/elasticsearch_store.py @@ -211,6 +211,10 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) - results: List[dict] = [] criteria = dict(criteria) if criteria else {} + limit = kwargs.get('limit') + offset = max(kwargs.get('offset', 0) or 0, 0) + return_total = kwargs.get('return_total', False) + sort_by_number = kwargs.get('sort_by_number', False) # Query by primary key(mget) if criteria and self._primary_key in criteria: vals = criteria.pop(self._primary_key) @@ -224,6 +228,28 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) - seg = self._transform_segment(doc) if seg: results.append(seg) + if sort_by_number: + results = sorted(results, key=lambda item: (item.get('number', 0), item.get('uid', ''))) + total = len(results) + if offset > 0 or limit is not None: + end = None if limit is None else offset + limit + results = results[offset:end] + return (results, total) if return_total else results + elif sort_by_number and (limit is not None or offset > 0 or return_total): + body = self._construct_criteria(criteria) or {'query': {'match_all': {}}} + body['sort'] = [{'number': {'order': 'asc'}}, {'_id': {'order': 'asc'}}] + if offset > 0: + body['from'] = offset + if limit is not None: + body['size'] = limit + elif offset > 0: + body['size'] = 10000 + if return_total: + body['track_total_hits'] = True + resp = self._client.search(index=collection_name, body=body) + results = [self._transform_segment(hit) for hit in resp['hits']['hits']] + total = resp['hits']['total']['value'] if return_total else len(results) + return (results, total) if return_total else results else: helpers = elasticsearch.helpers @@ -318,7 +344,19 @@ def _construct_criteria(self, criteria: Optional[dict] = None) -> dict: # noqa: vals = [vals] return {'query': {'ids': {'values': vals}}} + exact_match_fields = {'doc_id', 'kb_id', 'group', 'parent'} + def _add_clause(key, val): + if key in exact_match_fields: + clauses = [] + if isinstance(val, list): + clauses.append({'terms': {key: val}}) + clauses.append({'terms': {f'{key}.keyword': val}}) + else: + 
clauses.append({'term': {key: val}}) + clauses.append({'term': {f'{key}.keyword': val}}) + must_clauses.append({'bool': {'should': clauses, 'minimum_should_match': 1}}) + return if isinstance(val, list): must_clauses.append({'terms': {key: val}}) else: diff --git a/lazyllm/tools/rag/store/segment/opensearch_store.py b/lazyllm/tools/rag/store/segment/opensearch_store.py index 696169316..0368bdfcd 100644 --- a/lazyllm/tools/rag/store/segment/opensearch_store.py +++ b/lazyllm/tools/rag/store/segment/opensearch_store.py @@ -156,13 +156,17 @@ def delete(self, collection_name: str, criteria: Optional[dict] = None, **kwargs return False @override - def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: + def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) -> List[dict]: # noqa: C901 try: if not self._client.indices.exists(index=collection_name): LOG.warning(f'[OpenSearchStore - get] Index {collection_name} does not exist') return [] results: List[dict] = [] criteria = dict(criteria) if criteria else {} + limit = kwargs.get('limit') + offset = max(kwargs.get('offset', 0) or 0, 0) + return_total = kwargs.get('return_total', False) + sort_by_number = kwargs.get('sort_by_number', False) if criteria and self._primary_key in criteria: vals = criteria.pop(self._primary_key) if not isinstance(vals, list): @@ -172,6 +176,28 @@ def get(self, collection_name: str, criteria: Optional[dict] = None, **kwargs) - for doc in resp['docs']: if doc.get('found', False): results.append(self._transform_segment(doc)) + if sort_by_number: + results = sorted(results, key=lambda item: (item.get('number', 0), item.get('uid', ''))) + total = len(results) + if offset > 0 or limit is not None: + end = None if limit is None else offset + limit + results = results[offset:end] + return (results, total) if return_total else results + elif sort_by_number and (limit is not None or offset > 0 or return_total): + body = self._construct_criteria(criteria) or {'query': {'match_all': {}}} + body['sort'] = [{'number': {'order': 'asc'}}, {'_id': {'order': 'asc'}}] + if offset > 0: + body['from'] = offset + if limit is not None: + body['size'] = limit + elif offset > 0: + body['size'] = 10000 + if return_total: + body['track_total_hits'] = True + resp = self._client.search(index=collection_name, body=body) + results = [self._transform_segment(hit) for hit in resp['hits']['hits']] + total = resp['hits']['total']['value'] if return_total else len(results) + return (results, total) if return_total else results else: spec = importlib.util.find_spec('opensearchpy.helpers') helpers = importlib.util.module_from_spec(spec) @@ -248,9 +274,15 @@ def _serialize_node(self, segment: dict): def _deserialize_node(self, segment: dict) -> dict: seg = dict(segment) if self._global_metadata_desc and self._global_metadata_desc == BUILDIN_GLOBAL_META_DESC: - seg['meta'] = json.loads(seg.get('meta', '{}')) - seg['global_meta'] = json.loads(seg.get('global_meta', '{}')) - seg['image_keys'] = json.loads(seg.get('image_keys', '[]')) + meta = seg.get('meta', '{}') + global_meta = seg.get('global_meta', '{}') + image_keys = seg.get('image_keys', '[]') + if isinstance(meta, (str, bytes, bytearray)): + seg['meta'] = json.loads(meta) + if isinstance(global_meta, (str, bytes, bytearray)): + seg['global_meta'] = json.loads(global_meta) + if isinstance(image_keys, (str, bytes, bytearray)): + seg['image_keys'] = json.loads(image_keys) return seg def _construct_criteria(self, criteria: Optional[dict] = None) 
-> dict: # noqa: C901 @@ -263,7 +295,19 @@ def _construct_criteria(self, criteria: Optional[dict] = None) -> dict: # noqa: vals = [vals] return {'query': {'ids': {'values': vals}}} + exact_match_fields = {'doc_id', 'kb_id', 'group', 'parent'} + def _add_clause(key, val): + if key in exact_match_fields: + clauses = [] + if isinstance(val, list): + clauses.append({'terms': {key: val}}) + clauses.append({'terms': {f'{key}.keyword': val}}) + else: + clauses.append({'term': {key: val}}) + clauses.append({'term': {f'{key}.keyword': val}}) + must_clauses.append({'bool': {'should': clauses, 'minimum_should_match': 1}}) + return if isinstance(val, list): must_clauses.append({'terms': {key: val}}) else: diff --git a/lazyllm/tools/rag/transform/base.py b/lazyllm/tools/rag/transform/base.py index 67d64b8d9..52cac8313 100644 --- a/lazyllm/tools/rag/transform/base.py +++ b/lazyllm/tools/rag/transform/base.py @@ -1,12 +1,11 @@ from copy import copy as copy_obj -from enum import Enum from dataclasses import dataclass, field from typing import ( Any, List, Union, Optional, Tuple, AbstractSet, Collection, Literal, Callable, Dict, Iterator ) from lazyllm import LOG -from ..doc_node import DocNode, RichDocNode +from ..doc_node import DocNode, RichDocNode, MetadataMode from ....common.deprecated import deprecated from lazyllm import ThreadPoolExecutor from itertools import chain @@ -21,12 +20,6 @@ from lazyllm.thirdparty import nltk from lazyllm.thirdparty import transformers -class MetadataMode(str, Enum): - ALL = 'ALL' - EMBED = 'EMBED' - LLM = 'LLM' - NONE = 'NONE' - @dataclass class _Split: text: str diff --git a/lazyllm/tools/rag/utils.py b/lazyllm/tools/rag/utils.py index c6b0eebe9..3fc373991 100644 --- a/lazyllm/tools/rag/utils.py +++ b/lazyllm/tools/rag/utils.py @@ -24,7 +24,7 @@ from sqlalchemy.orm import DeclarativeBase, sessionmaker import lazyllm -from lazyllm import config +from lazyllm import config, deprecated from lazyllm.common import override from lazyllm.common.queue import sqlite3_check_threadsafety from lazyllm.thirdparty import tarfile @@ -45,6 +45,8 @@ config.add('default_dlmanager', str, 'sqlite', 'DEFAULT_DOCLIST_MANAGER', description='The default document list manager for RAG.') +RAG_DEFAULT_GROUP_NAME = '__default__' + def gen_docid(file_path: str) -> str: return hashlib.sha256(file_path.encode()).hexdigest() @@ -102,8 +104,9 @@ class DocPathParsingResult(BaseModel): msg: str is_new: bool = False +@deprecated('Document(dataset_path=..., enable_path_monitoring=...)') class DocListManager(ABC): - DEFAULT_GROUP_NAME = '__default__' + DEFAULT_GROUP_NAME = RAG_DEFAULT_GROUP_NAME __pool__ = dict() class Status: @@ -883,6 +886,13 @@ def _check_batch(fn): batch_size = getattr(fn, 'batch_size', None) return isinstance(batch_size, Integral) and batch_size > 1 + def _check_empty_embedding_item(vec, embed_key: str, idx: int) -> None: + if vec is None: + raise ValueError(f'[LazyLLM - parallel_do_embedding][{embed_key}] invalid embedding at index {idx}: None') + if isinstance(vec, (list, dict)) and len(vec) == 0: + raise ValueError(f'[LazyLLM - parallel_do_embedding][{embed_key}] ' + f'invalid embedding at index {idx}: empty {type(vec).__name__}') + def _process_key(k: str, knodes: List[DocNode]): try: fn = embed[k] @@ -892,7 +902,8 @@ def _process_key(k: str, knodes: List[DocNode]): if len(vecs) != len(texts): raise ValueError(f'[LazyLLM - parallel_do_embedding][{k}] batch size mismatch: ' f'[text_num:{len(texts)}] vs [vec_num:{len(vecs)}]') - for n, v in zip(knodes, vecs): + for idx, (n, v) in 
enumerate(zip(knodes, vecs)):
+            _check_empty_embedding_item(v, k, idx)
             n.set_embedding(k, v)
     return

diff --git a/lazyllm/tools/rag/web.py b/lazyllm/tools/rag/web.py
index e0a7768fc..addba74ad 100644
--- a/lazyllm/tools/rag/web.py
+++ b/lazyllm/tools/rag/web.py
@@ -51,10 +51,13 @@ def delete_group(self, group_name: str):
         return response.json()['msg']

     def list_groups(self):
-        response = requests.get(
-            f'{self.base_url}/list_kb_groups', headers=self.basic_headers(False)
-        )
-        return response.json()['data']
+        response = requests.get(f'{self.base_url}/list_kb_groups', headers=self.basic_headers(False))
+        payload = response.json()
+        if 'data' in payload:
+            return payload['data']
+        response = requests.get(f'{self.base_url}/v1/kbs', headers=self.basic_headers(False))
+        payload = response.json()
+        return [item.get('kb_id') for item in payload.get('data', {}).get('items', []) if item.get('kb_id')]

     def upload_files(self, group_name: str, override: bool = True):
         response = requests.post(
@@ -240,10 +243,11 @@ def wait(self):
         self.demo.block_thread()

     def stop(self):
-        if self.demo:
-            self.demo.close()
+        demo = self.__dict__.get('demo')
+        if demo:
+            demo.close()
             del self.demo
-        self.demo, self.url = None, ''
+        self.url = ''

     def _find_can_use_network_port(self):
         for port in self.port:

diff --git a/lazyllm/tools/servers/mineru/mineru_patches.py b/lazyllm/tools/servers/mineru/mineru_patches.py
index 4d061dd9e..1b4ec05b4 100644
--- a/lazyllm/tools/servers/mineru/mineru_patches.py
+++ b/lazyllm/tools/servers/mineru/mineru_patches.py
@@ -6,17 +6,24 @@
     merge_para_with_text as pipeline_merge_para_with_text,
     get_title_level,
 )
-from mineru.backend.vlm import (  # noqa: NID002
-    vlm_middle_json_mkcontent,
-)
-from mineru.backend.vlm.vlm_middle_json_mkcontent import (  # noqa: NID002
-    merge_para_with_text as vlm_merge_para_with_text,
-)
 from mineru.utils.enum_class import (  # noqa: NID002
     BlockType,
     ContentType,
 )
+try:
+    from mineru.backend.vlm import (  # noqa: NID002
+        vlm_middle_json_mkcontent,
+    )
+    from mineru.backend.vlm.vlm_middle_json_mkcontent import (  # noqa: NID002
+        merge_para_with_text as vlm_merge_para_with_text,
+    )
+    HAS_VLM_BACKEND = True
+except Exception:  # pragma: no cover - optional dependency path
+    vlm_middle_json_mkcontent = None
+    vlm_merge_para_with_text = None
+    HAS_VLM_BACKEND = False
+

 # patches to mineru (to output bbox)
 def _parse_line_spans(para_block, page_idx):
@@ -243,4 +250,5 @@ def vlm_make_blocks_to_content_list(para_block, img_buket_path, page_idx, page_s
         para_content['page_height'] = page_height
     return para_content

-vlm_middle_json_mkcontent.make_blocks_to_content_list = vlm_make_blocks_to_content_list
+if HAS_VLM_BACKEND:
+    vlm_middle_json_mkcontent.make_blocks_to_content_list = vlm_make_blocks_to_content_list

diff --git a/lazyllm/tools/servers/mineru/mineru_server_module.py b/lazyllm/tools/servers/mineru/mineru_server_module.py
index de5419f7a..99916321d 100644
--- a/lazyllm/tools/servers/mineru/mineru_server_module.py
+++ b/lazyllm/tools/servers/mineru/mineru_server_module.py
@@ -853,10 +853,11 @@ async def parse_pdf(  # noqa: C901
         try:
             results = {f: {} for f in files}
+            cache_enabled = bool(self.cache_manager and self.cache_manager.is_enabled)

             # honor use_cache for parsed outputs
             files_to_process = files
-            if use_cache and self.cache_manager.is_enabled:
+            if use_cache and cache_enabled:
                 results, files_to_process = await self.cache_manager.check_parse_cache(
                     files, results, effective_backend, return_md, return_content_list,
                     table_enable, formula_enable, lang=lang, parse_method=parse_method
@@ -872,7 +873,7 @@ async def parse_pdf(  # noqa: C901
                     status_code=200, content={'result': results_list, 'unique_id': unique_id}
                 )
-            elif use_cache and not self.cache_manager.is_enabled:
+            elif use_cache and not cache_enabled:
                 LOG.warning(f'[{req_id}] CACHE_DIR not set; use_cache ignored.')

             # 1) Convert to PDFs with caching support
@@ -965,7 +966,7 @@ async def parse_pdf(  # noqa: C901
                     results[src_path]['md_content'] = md_content
                 if return_content_list and content_list:
                     results[src_path]['content_list'] = content_list
-                if self.cache_manager.is_enabled:
+                if cache_enabled:
                     await self.cache_manager.save_parse_result(
                         hash_id, results[src_path], mode=effective_backend,
                         table_enable=table_enable, formula_enable=formula_enable,

diff --git a/lazyllm/tools/sql/sql_manager.py b/lazyllm/tools/sql/sql_manager.py
index e37e438b9..685be6732 100644
--- a/lazyllm/tools/sql/sql_manager.py
+++ b/lazyllm/tools/sql/sql_manager.py
@@ -159,30 +159,72 @@ def _mysql_engine_kwargs(self) -> dict:
         kwargs.update({'pool_recycle': 3600})
         return kwargs

+    @staticmethod
+    def _default_engine_kwargs() -> dict:
+        return {
+            'pool_size': 10,
+            'max_overflow': 20,
+            'pool_pre_ping': True,
+            'pool_recycle': 3600,
+        }
+
     @staticmethod
     def _get_operational_error_code(error: OperationalError):
         args = getattr(getattr(error, 'orig', None), 'args', ())
         return args[0] if args else None

+    @staticmethod
+    def _get_operational_error_pgcode(error: OperationalError):
+        orig = getattr(error, 'orig', None)
+        return getattr(orig, 'pgcode', None) or getattr(orig, 'sqlstate', None)
+
+    def _is_database_not_found_error(self, error: OperationalError) -> bool:
+        if self._db_type in ('mysql', 'mysql+pymysql', 'tidb'):
+            return self._get_operational_error_code(error) == 1049
+        if self._db_type == 'postgresql':
+            if self._get_operational_error_pgcode(error) == '3D000':
+                return True
+            error_msg = str(getattr(error, 'orig', error)).lower()
+            return 'does not exist' in error_msg and 'database' in error_msg
+        return False
+
     def _ensure_database_exists(self, conn_url: str):
-        if self._db_type not in ('mysql', 'mysql+pymysql', 'tidb'):
+        if self._db_type not in ('mysql', 'mysql+pymysql', 'tidb', 'postgresql'):
             return
-        probe_engine = sqlalchemy.create_engine(conn_url, **self._mysql_engine_kwargs())
+        engine_kwargs = self._mysql_engine_kwargs() if self._db_type in ('mysql', 'mysql+pymysql', 'tidb') \
+            else self._default_engine_kwargs()
+        probe_engine = sqlalchemy.create_engine(conn_url, **engine_kwargs)
         try:
             with probe_engine.connect():
                 return
         except OperationalError as e:
-            if self._get_operational_error_code(e) != 1049:
+            if not self._is_database_not_found_error(e):
                 raise
         finally:
             probe_engine.dispose()

-        admin_engine = sqlalchemy.create_engine(self._gen_conn_url(''), **self._mysql_engine_kwargs())
+        if self._db_type == 'postgresql':
+            admin_engine = sqlalchemy.create_engine(
+                self._gen_conn_url('postgres'),
+                isolation_level='AUTOCOMMIT',
+                **self._default_engine_kwargs()
+            )
+        else:
+            admin_engine = sqlalchemy.create_engine(self._gen_conn_url(''), **self._mysql_engine_kwargs())
         try:
-            escaped_db_name = self._db_name.replace('`', '``')
             with admin_engine.connect() as conn:
-                conn.execute(sqlalchemy.text(f'CREATE DATABASE IF NOT EXISTS `{escaped_db_name}`'))
-                conn.commit()
+                if self._db_type == 'postgresql':
+                    exists = conn.execute(
+                        sqlalchemy.text('SELECT 1 FROM pg_database WHERE datname = :db_name'),
+                        {'db_name': self._db_name}
+                    ).scalar()
+                    if not exists:
+                        escaped_db_name = self._db_name.replace('"', '""')
+                        conn.execute(sqlalchemy.text(f'CREATE DATABASE "{escaped_db_name}"'))
+                else:
+                    escaped_db_name = self._db_name.replace('`', '``')
+                    conn.execute(sqlalchemy.text(f'CREATE DATABASE IF NOT EXISTS `{escaped_db_name}`'))
+                    conn.commit()
         finally:
             admin_engine.dispose()

@@ -205,14 +247,11 @@ def engine(self):
         elif self._db_type in ('mysql', 'mysql+pymysql', 'tidb'):
             self._ensure_database_exists(conn_url)
             self._engine = sqlalchemy.create_engine(conn_url, **self._mysql_engine_kwargs())
+        elif self._db_type == 'postgresql':
+            self._ensure_database_exists(conn_url)
+            self._engine = sqlalchemy.create_engine(conn_url, **self._default_engine_kwargs())
         else:
-            self._engine = sqlalchemy.create_engine(
-                conn_url,
-                pool_size=10,
-                max_overflow=20,
-                pool_pre_ping=True,
-                pool_recycle=3600
-            )
+            self._engine = sqlalchemy.create_engine(conn_url, **self._default_engine_kwargs())
         return self._engine

     @property

diff --git a/tests/basic_tests/RAG/test_doc_processor.py b/tests/basic_tests/RAG/test_doc_processor.py
index c6980faf9..eefcada2b 100644
--- a/tests/basic_tests/RAG/test_doc_processor.py
+++ b/tests/basic_tests/RAG/test_doc_processor.py
@@ -6,10 +6,11 @@
 import pytest

 from lazyllm.tools.rag.parsing_service import DocumentProcessor
-from lazyllm.tools.rag.parsing_service.base import TaskStatus
+from lazyllm.tools.rag.parsing_service.base import TaskStatus, TaskType
+from lazyllm.tools.rag.parsing_service.worker import DocumentProcessorWorker
 from lazyllm import Document, Retriever

-STATIC_STATUS = [TaskStatus.FINISHED.value, TaskStatus.FAILED.value, TaskStatus.CANCELED.value]
+STATIC_STATUS = [TaskStatus.SUCCESS.value, TaskStatus.FAILED.value, TaskStatus.CANCELED.value]
 records = []

 def post_func_sample(task_id: str, task_status: str, error_code: str = None, error_msg: str = None):
@@ -132,6 +133,49 @@ def test_upload_doc(self):
         except requests.exceptions.RequestException as e:
             self.fail(f'Request failed: {e}')

+
+class _FakeFinishedTaskQueue:
+    def __init__(self):
+        self.items = []
+
+    def enqueue(self, **kwargs):
+        self.items.append(kwargs)
+
+
+def test_worker_callback_filters_allow_all_by_default():
+    impl = DocumentProcessorWorker._Impl()
+    queue = _FakeFinishedTaskQueue()
+    impl._lazy_init = lambda: None
+    impl._finished_task_queue = queue
+
+    impl._enqueue_finished_task(
+        task_id='task-1',
+        task_type=TaskType.DOC_ADD.value,
+        task_status=TaskStatus.WORKING,
+        callback_url='http://callback.test',
+    )
+
+    assert [item['task_id'] for item in queue.items] == ['task-1']
+
+
+def test_worker_callback_filters_skip_non_matching_records():
+    impl = DocumentProcessorWorker._Impl(
+        callback_task_statuses=[TaskStatus.SUCCESS.value],
+        callback_task_types=[TaskType.DOC_ADD.value],
+    )
+    queue = _FakeFinishedTaskQueue()
+    impl._lazy_init = lambda: None
+    impl._finished_task_queue = queue
+
+    impl._enqueue_finished_task(task_id='task-delete', task_type=TaskType.DOC_DELETE.value,
+                                task_status=TaskStatus.SUCCESS)
+    impl._enqueue_finished_task(task_id='task-working', task_type=TaskType.DOC_ADD.value,
+                                task_status=TaskStatus.WORKING)
+    impl._enqueue_finished_task(task_id='task-add-success', task_type=TaskType.DOC_ADD.value,
+                                task_status=TaskStatus.SUCCESS)
+
+    assert [item['task_id'] for item in queue.items] == ['task-add-success']
+
     @pytest.mark.order(2)
     def test_reparse(self):
         with open(self._file_path, 'w') as f:

diff --git a/tests/basic_tests/RAG/test_doc_processor_transfer.py b/tests/basic_tests/RAG/test_doc_processor_transfer.py
new file mode 100644
index 000000000..73adbed3c
--- /dev/null
+++
b/tests/basic_tests/RAG/test_doc_processor_transfer.py @@ -0,0 +1,31 @@ +from lazyllm.tools.rag.parsing_service import DocumentProcessor +from lazyllm.tools.rag.parsing_service.base import AddDocRequest, FileInfo, TaskType, TransferParams + + +def test_resolve_add_task_type_returns_transfer_for_copy_request(): + request = AddDocRequest( + algo_id='general_algo', + kb_id='kb_source', + file_infos=[FileInfo( + file_path='/tmp/source.pdf', + doc_id='doc_source', + transfer_params=TransferParams( + mode='cp', + target_algo_id='general_algo', + target_doc_id='doc_target', + target_kb_id='kb_target', + ), + )], + ) + + assert DocumentProcessor._Impl._resolve_add_task_type(request) == TaskType.DOC_TRANSFER.value + + +def test_resolve_add_task_type_keeps_add_for_regular_upload(): + request = AddDocRequest( + algo_id='general_algo', + kb_id='kb_source', + file_infos=[FileInfo(file_path='/tmp/source.pdf', doc_id='doc_source')], + ) + + assert DocumentProcessor._Impl._resolve_add_task_type(request) == TaskType.DOC_ADD.value diff --git a/tests/basic_tests/RAG/test_doc_service_doc_manager.py b/tests/basic_tests/RAG/test_doc_service_doc_manager.py new file mode 100644 index 000000000..9ace9c6be --- /dev/null +++ b/tests/basic_tests/RAG/test_doc_service_doc_manager.py @@ -0,0 +1,596 @@ +import os +import tempfile +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from uuid import uuid4 + +import pytest + +from lazyllm.tools.rag.doc_service.base import ( + AddFileItem, + CallbackEventType, + DeleteRequest, + DocServiceError, + DocStatus, + ReparseRequest, + SourceType, + TaskCallbackRequest, + TransferItem, + TransferRequest, + UploadRequest, +) +from lazyllm.tools.rag.doc_service.doc_manager import DocManager +from lazyllm.tools.rag.doc_service.parser_client import ParserClient +from lazyllm.tools.rag.parsing_service.base import TaskType +from lazyllm.tools.rag.utils import BaseResponse + + +class _ManagerHarness: + def __init__(self): + self._tmp_dir = tempfile.TemporaryDirectory(prefix='lazyllm_doc_service_manager_') + self.tmp_dir = self._tmp_dir.name + self.seed_path = self.make_file('seed.txt', 'seed content') + self.db_config = { + 'db_type': 'sqlite', + 'user': None, + 'password': None, + 'host': None, + 'port': None, + 'db_name': os.path.join(self.tmp_dir, 'doc_service_local.db'), + } + original_health = ParserClient.health + ParserClient.health = lambda self: BaseResponse(code=200, msg='success', data={'ok': True}) + try: + self.manager = DocManager(db_config=self.db_config, parser_url='http://parser.test') + finally: + ParserClient.health = original_health + self.pending_task_status = {} + self.cancel_calls = [] + self.delete_calls = [] + self.chunk_calls = [] + self.add_doc_calls = [] + self.chunk_response = BaseResponse(code=200, msg='success', data={'items': [], 'total': 0}) + self._patch_parser_client() + + def close(self): + self._tmp_dir.cleanup() + + def make_file(self, name: str, content: str): + file_path = os.path.join(self.tmp_dir, name) + with open(file_path, 'w', encoding='utf-8') as file: + file.write(content) + return file_path + + def finish_task(self, task_id: str, status: DocStatus = DocStatus.SUCCESS, callback_id: str = None): + return self.manager.on_task_callback(TaskCallbackRequest( + callback_id=callback_id or str(uuid4()), + task_id=task_id, + event_type=CallbackEventType.FINISH, + status=status, + )) + + def start_task(self, task_id: str, callback_id: str = None): + return self.manager.on_task_callback(TaskCallbackRequest( + 
callback_id=callback_id or str(uuid4()), + task_id=task_id, + event_type=CallbackEventType.START, + status=DocStatus.WORKING, + )) + + def _queue_task(self, task_id: str, final_status: DocStatus): + self.pending_task_status[task_id] = final_status + + def _patch_parser_client(self): + def add_doc(task_id, algo_id, kb_id, doc_id, file_path, metadata=None, reparse_group=None, + callback_url=None, transfer_params=None): + self.add_doc_calls.append({ + 'task_id': task_id, + 'algo_id': algo_id, + 'kb_id': kb_id, + 'doc_id': doc_id, + 'file_path': file_path, + 'metadata': metadata, + 'reparse_group': reparse_group, + 'callback_url': callback_url, + 'transfer_params': transfer_params, + }) + self._queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + + def update_meta(task_id, algo_id, kb_id, doc_id, metadata=None, file_path=None, callback_url=None): + del doc_id, metadata, file_path, callback_url + self._queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + + def delete_doc(task_id, algo_id, kb_id, doc_id, callback_url=None): + del callback_url + self.delete_calls.append({'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id, 'doc_id': doc_id}) + self._queue_task(task_id, DocStatus.SUCCESS) + return BaseResponse(code=200, msg='success', data={'task_id': task_id, 'algo_id': algo_id, 'kb_id': kb_id}) + + def cancel_task(task_id): + self.cancel_calls.append(task_id) + return BaseResponse(code=200, msg='success', data={'task_id': task_id, 'cancel_status': True}) + + def list_doc_chunks(algo_id, kb_id, doc_id, group, offset, page_size): + self.chunk_calls.append({ + 'algo_id': algo_id, + 'kb_id': kb_id, + 'doc_id': doc_id, + 'group': group, + 'offset': offset, + 'page_size': page_size, + }) + return self.chunk_response + + self.manager._parser_client.add_doc = add_doc + self.manager._parser_client.update_meta = update_meta + self.manager._parser_client.delete_doc = delete_doc + self.manager._parser_client.cancel_task = cancel_task + self.manager._parser_client.list_doc_chunks = list_doc_chunks + self.manager._parser_client.list_algorithms = lambda: BaseResponse( + code=200, + msg='success', + data=[{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}], + ) + self.manager._parser_client.get_algorithm_groups = lambda algo_id: BaseResponse( + code=200, + msg='success', + data=[{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}] if algo_id == '__default__' else None, + ) + + +@pytest.fixture +def manager_harness(): + harness = _ManagerHarness() + try: + yield harness + finally: + harness.close() + + +def test_manager_run_idempotent_atomic(manager_harness): + started = [] + + def handler(): + started.append(time.time()) + time.sleep(0.2) + return {'task_id': str(uuid4())} + + with ThreadPoolExecutor(max_workers=2) as pool: + future = pool.submit(manager_harness.manager.run_idempotent, '/local/atomic', 'same-key', {'k': 1}, handler) + time.sleep(0.05) + with pytest.raises(DocServiceError) as exc: + manager_harness.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) + result = future.result(timeout=2) + + assert exc.value.biz_code == 'E_IDEMPOTENCY_IN_PROGRESS' + replay = manager_harness.manager.run_idempotent('/local/atomic', 'same-key', {'k': 1}, handler) + assert len(started) == 1 + assert replay == result + + +def 
test_manager_upload_callback_and_doc_detail(manager_harness): + manager_harness.manager.create_kb('kb_upload', algo_id='__default__') + file_path = manager_harness.make_file('upload.txt', 'upload content') + + items = manager_harness.manager.upload(UploadRequest( + kb_id='kb_upload', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='upload-doc')], + )) + + task_id = items[0]['task_id'] + assert items[0]['accepted'] is True + assert items[0]['parse_status'] == DocStatus.WAITING.value + + manager_harness.finish_task(task_id) + + task = manager_harness.manager.get_task(task_id) + detail = manager_harness.manager.get_doc_detail('upload-doc') + + assert task.code == 200 + assert task.data['status'] == DocStatus.SUCCESS.value + assert detail['doc']['upload_status'] == DocStatus.SUCCESS.value + assert detail['snapshot']['status'] == DocStatus.SUCCESS.value + assert detail['latest_task']['task_id'] == task_id + + +def test_manager_cancel_waiting_add_updates_all_states(manager_harness): + manager_harness.manager.create_kb('kb_cancel', algo_id='__default__') + file_path = manager_harness.make_file('cancel.txt', 'cancel content') + + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_cancel', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='cancel-doc')], + )) + task_id = upload[0]['task_id'] + + resp = manager_harness.manager.cancel_task(task_id) + snapshot = manager_harness.manager._get_parse_snapshot('cancel-doc', 'kb_cancel', '__default__') + doc = manager_harness.manager._get_doc('cancel-doc') + task = manager_harness.manager.get_task(task_id) + + assert resp.code == 200 + assert resp.data['status'] == DocStatus.CANCELED.value + assert task.data['status'] == DocStatus.CANCELED.value + assert snapshot['status'] == DocStatus.CANCELED.value + assert doc['upload_status'] == DocStatus.CANCELED.value + + +def test_manager_cancel_working_task_rejected(manager_harness): + manager_harness.manager.create_kb('kb_working', algo_id='__default__') + file_path = manager_harness.make_file('working.txt', 'working content') + + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_working', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='working-doc')], + )) + task_id = upload[0]['task_id'] + manager_harness.start_task(task_id) + + resp = manager_harness.manager.cancel_task(task_id) + + assert resp.code == 409 + assert resp.data['status'] == DocStatus.WORKING.value + + +def test_manager_delete_waiting_add_uses_cancel_path(manager_harness): + manager_harness.manager.create_kb('kb_delete_waiting', algo_id='__default__') + file_path = manager_harness.make_file('delete_waiting.txt', 'delete waiting content') + + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_delete_waiting', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='delete-waiting-doc')], + )) + original_task_id = upload[0]['task_id'] + + items = manager_harness.manager.delete(DeleteRequest( + kb_id='kb_delete_waiting', + algo_id='__default__', + doc_ids=['delete-waiting-doc'], + )) + snapshot = manager_harness.manager._get_parse_snapshot('delete-waiting-doc', 'kb_delete_waiting', '__default__') + doc = manager_harness.manager._get_doc('delete-waiting-doc') + + assert items[0]['task_id'] == original_task_id + assert items[0]['status'] == DocStatus.CANCELED.value + assert manager_harness.cancel_calls == [original_task_id] + assert manager_harness.delete_calls == [] + assert snapshot['status'] == 
DocStatus.CANCELED.value + assert doc['upload_status'] == DocStatus.CANCELED.value + + +def test_manager_stale_callback_ignored_after_reparse(manager_harness): + manager_harness.manager.create_kb('kb_stale', algo_id='__default__') + file_path = manager_harness.make_file('stale.txt', 'stale content') + + uploaded = manager_harness.manager.upload(UploadRequest( + kb_id='kb_stale', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='stale-doc')], + )) + manager_harness.finish_task(uploaded[0]['task_id']) + + first_task_id = manager_harness.manager.reparse(ReparseRequest( + kb_id='kb_stale', + algo_id='__default__', + doc_ids=['stale-doc'], + ))[0] + second_task_id = manager_harness.manager.reparse(ReparseRequest( + kb_id='kb_stale', + algo_id='__default__', + doc_ids=['stale-doc'], + ))[0] + + stale_resp = manager_harness.manager.on_task_callback(TaskCallbackRequest( + callback_id='stale-callback', + task_id=first_task_id, + event_type=CallbackEventType.FINISH, + status=DocStatus.SUCCESS, + )) + + assert first_task_id != second_task_id + assert stale_resp['ignored_reason'] == 'stale_task_callback' + + +def test_manager_transfer_uses_target_doc_id_for_target_records(manager_harness): + manager_harness.manager.create_kb('kb_transfer_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_transfer_target', algo_id='__default__') + file_path = manager_harness.make_file('transfer.txt', 'transfer content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_transfer_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='source-doc')], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='source-doc', + target_doc_id='target-doc-copy', + source_kb_id='kb_transfer_source', + source_algo_id='__default__', + target_kb_id='kb_transfer_target', + target_algo_id='__default__', + mode='copy', + )])) + + task_id = items[0]['task_id'] + manager_harness.finish_task(task_id) + target_doc = manager_harness.manager._get_doc('target-doc-copy') + + assert items[0]['doc_id'] == 'source-doc' + assert items[0]['target_doc_id'] == 'target-doc-copy' + assert manager_harness.add_doc_calls[-1]['doc_id'] == 'source-doc' + assert manager_harness.add_doc_calls[-1]['transfer_params']['target_doc_id'] == 'target-doc-copy' + assert manager_harness.add_doc_calls[-1]['file_path'] == file_path + assert manager_harness.manager._has_kb_document('kb_transfer_target', 'target-doc-copy') is True + assert manager_harness.manager._has_kb_document('kb_transfer_target', 'source-doc') is False + assert target_doc['upload_status'] == DocStatus.SUCCESS.value + assert target_doc['filename'] == 'transfer.txt' + assert target_doc['meta'] == '{}' + assert target_doc['path'] == file_path + assert manager_harness.manager._has_kb_document('kb_transfer_source', 'source-doc') is True + + +def test_manager_transfer_same_kb_copy_reuses_source_path(manager_harness): + manager_harness.manager.create_kb('kb_same_transfer', algo_id='__default__') + file_path = manager_harness.make_file('same-transfer.txt', 'same transfer content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_same_transfer', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='same-source-doc')], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='same-source-doc', + 
target_doc_id='same-target-doc', + source_kb_id='kb_same_transfer', + source_algo_id='__default__', + target_kb_id='kb_same_transfer', + target_algo_id='__default__', + mode='copy', + )])) + + assert items[0]['accepted'] is True + assert items[0]['status'] == DocStatus.WAITING.value + assert items[0]['target_file_path'] == file_path + assert manager_harness.add_doc_calls[-1]['doc_id'] == 'same-source-doc' + assert manager_harness.add_doc_calls[-1]['transfer_params']['target_doc_id'] == 'same-target-doc' + assert manager_harness.manager._has_kb_document('kb_same_transfer', 'same-source-doc') is True + assert manager_harness.manager._has_kb_document('kb_same_transfer', 'same-target-doc') is True + + +def test_manager_transfer_move_cleans_source_doc_with_target_doc_id(manager_harness): + manager_harness.manager.create_kb('kb_move_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_move_target', algo_id='__default__') + file_path = manager_harness.make_file('move.txt', 'move content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_move_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='source-doc-move')], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='source-doc-move', + target_doc_id='target-doc-move', + source_kb_id='kb_move_source', + source_algo_id='__default__', + target_kb_id='kb_move_target', + target_algo_id='__default__', + mode='move', + )])) + + task_id = items[0]['task_id'] + manager_harness.finish_task(task_id) + + assert manager_harness.manager._has_kb_document('kb_move_source', 'source-doc-move') is False + assert manager_harness.manager._has_kb_document('kb_move_target', 'target-doc-move') is True + assert manager_harness.manager._get_doc('source-doc-move')['upload_status'] == DocStatus.DELETED.value + assert manager_harness.manager._get_doc('target-doc-move')['upload_status'] == DocStatus.SUCCESS.value + assert manager_harness.manager._get_parse_snapshot('source-doc-move', 'kb_move_source', '__default__') is None + + +def test_manager_transfer_target_fields_override_source_defaults(manager_harness): + manager_harness.manager.create_kb('kb_override_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_override_target', algo_id='__default__') + file_path = manager_harness.make_file('override.txt', 'override content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_override_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='source-doc-override', metadata={'a': 1, 'b': 'keep'})], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='source-doc-override', + target_doc_id='target-doc-override', + source_kb_id='kb_override_source', + source_algo_id='__default__', + target_kb_id='kb_override_target', + target_algo_id='__default__', + target_metadata={'a': 2, 'c': 'new'}, + target_filename='renamed.txt', + mode='copy', + )])) + + manager_harness.finish_task(items[0]['task_id']) + target_doc = manager_harness.manager._get_doc('target-doc-override') + + assert target_doc['filename'] == 'renamed.txt' + assert target_doc['meta'] == '{"a": 2, "b": "keep", "c": "new"}' + assert os.path.basename(target_doc['path']) == 'renamed.txt' + assert manager_harness.add_doc_calls[-1]['metadata'] == {'a': 2, 'b': 'keep', 'c': 'new'} + assert 
manager_harness.add_doc_calls[-1]['file_path'] == file_path + + +def test_manager_transfer_target_file_path_overrides_target_file_info(manager_harness): + manager_harness.manager.create_kb('kb_path_source', algo_id='__default__') + manager_harness.manager.create_kb('kb_path_target', algo_id='__default__') + file_path = manager_harness.make_file('path-source.txt', 'path source content') + upload = manager_harness.manager.upload(UploadRequest( + kb_id='kb_path_source', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='source-doc-path')], + )) + manager_harness.finish_task(upload[0]['task_id']) + + items = manager_harness.manager.transfer(TransferRequest(items=[TransferItem( + doc_id='source-doc-path', + target_doc_id='target-doc-path', + source_kb_id='kb_path_source', + source_algo_id='__default__', + target_kb_id='kb_path_target', + target_algo_id='__default__', + target_file_path='/virtual/target/renamed-from-path.txt', + mode='copy', + )])) + + manager_harness.finish_task(items[0]['task_id']) + target_doc = manager_harness.manager._get_doc('target-doc-path') + + assert items[0]['target_file_path'] == '/virtual/target/renamed-from-path.txt' + assert target_doc['filename'] == 'renamed-from-path.txt' + assert target_doc['path'] == '/virtual/target/renamed-from-path.txt' + assert manager_harness.add_doc_calls[-1]['file_path'] == file_path + + +def test_manager_list_chunks_forwards_true_pagination(manager_harness): + manager_harness.manager.create_kb('kb_chunks', algo_id='__default__') + file_path = manager_harness.make_file('chunks.txt', 'chunks content') + uploaded = manager_harness.manager.upload(UploadRequest( + kb_id='kb_chunks', + algo_id='__default__', + items=[AddFileItem(file_path=file_path, doc_id='chunks-doc')], + )) + manager_harness.finish_task(uploaded[0]['task_id']) + manager_harness.chunk_response = BaseResponse(code=200, msg='success', data={ + 'items': [{'uid': 'chunk-2'}, {'uid': 'chunk-3'}], + 'total': 7, + }) + + data = manager_harness.manager.list_chunks( + kb_id='kb_chunks', + doc_id='chunks-doc', + group='line', + algo_id='__default__', + page=2, + page_size=2, + ) + + assert manager_harness.chunk_calls == [{ + 'algo_id': '__default__', + 'kb_id': 'kb_chunks', + 'doc_id': 'chunks-doc', + 'group': 'line', + 'offset': 2, + 'page_size': 2, + }] + assert data['items'] == [{'uid': 'chunk-2'}, {'uid': 'chunk-3'}] + assert data['total'] == 7 + assert data['page'] == 2 + assert data['offset'] == 2 + + +def test_manager_callback_payload_fallback_and_delete_transition(manager_harness): + manager_harness.manager.create_kb('kb_callback', algo_id='__default__') + file_path = manager_harness.make_file('callback.txt', 'callback content') + manager_harness.manager._upsert_doc( + doc_id='callback-doc', + filename='callback.txt', + path=file_path, + metadata={'case': 'callback'}, + source_type=SourceType.EXTERNAL, + ) + manager_harness.manager._ensure_kb_document('kb_callback', 'callback-doc') + queued_at = manager_harness.manager._upsert_parse_snapshot( + doc_id='callback-doc', + kb_id='kb_callback', + algo_id='__default__', + status=DocStatus.DELETING, + task_type=TaskType.DOC_DELETE, + current_task_id='delete-task', + queued_at=datetime.now(), + )['queued_at'] + + start_resp = manager_harness.manager.on_task_callback(TaskCallbackRequest( + callback_id='delete-start', + task_id='delete-task', + event_type=CallbackEventType.START, + status=DocStatus.WORKING, + payload={ + 'task_type': TaskType.DOC_DELETE.value, + 'doc_id': 'callback-doc', + 'kb_id': 
'kb_callback', + 'algo_id': '__default__', + }, + )) + start_snapshot = manager_harness.manager._get_parse_snapshot('callback-doc', 'kb_callback', '__default__') + finish_resp = manager_harness.manager.on_task_callback(TaskCallbackRequest( + callback_id='delete-finish', + task_id='delete-task', + event_type=CallbackEventType.FINISH, + status=DocStatus.SUCCESS, + payload={ + 'task_type': TaskType.DOC_DELETE.value, + 'doc_id': 'callback-doc', + 'kb_id': 'kb_callback', + 'algo_id': '__default__', + }, + )) + snapshot = manager_harness.manager._get_parse_snapshot('callback-doc', 'kb_callback', '__default__') + + assert start_resp['ack'] is True + assert finish_resp['ack'] is True + assert start_snapshot['queued_at'] == queued_at + assert snapshot['status'] == DocStatus.DELETED.value + assert manager_harness.manager._has_kb_document('kb_callback', 'callback-doc') is False + assert manager_harness.manager._get_doc('callback-doc')['upload_status'] == DocStatus.DELETED.value + + +def test_parser_client_algo_endpoint_fallback(): + client = ParserClient(parser_url='http://parser.test') + calls = [] + + def fake_request(method, path, params=None, **kwargs): + assert method == 'GET' + assert kwargs == {} + del params + calls.append(path) + if path == '/v1/algo/list': + raise RuntimeError('parser http error: 404 missing route') + if path == '/algo/list': + return { + 'code': 200, + 'msg': 'success', + 'data': [{'algo_id': '__default__', 'display_name': 'Default', 'description': 'desc'}], + } + if path == '/v1/algo/__default__/groups': + raise RuntimeError('parser http error: 404 missing route') + if path == '/algo/__default__/group/info': + return { + 'code': 200, + 'msg': 'success', + 'data': [{'name': 'line', 'type': 'chunk', 'display_name': 'Line'}], + } + raise AssertionError(path) + + client._request = fake_request + + algo_resp = client.list_algorithms() + group_resp = client.get_algorithm_groups('__default__') + + assert algo_resp.code == 200 + assert group_resp.code == 200 + assert calls == [ + '/v1/algo/list', + '/algo/list', + '/v1/algo/__default__/groups', + '/algo/__default__/group/info', + ] diff --git a/tests/basic_tests/RAG/test_doc_service_doc_server.py b/tests/basic_tests/RAG/test_doc_service_doc_server.py new file mode 100644 index 000000000..ee10740e8 --- /dev/null +++ b/tests/basic_tests/RAG/test_doc_service_doc_server.py @@ -0,0 +1,248 @@ +import asyncio +import json +import os +import tempfile + +import pytest +import requests +from pydantic import ValidationError + +from lazyllm.thirdparty import fastapi + +from lazyllm.tools.rag.doc_service.doc_server import DocServer +from lazyllm.tools.rag.doc_service.base import ( + AddFileItem, + CallbackEventType, + DocServiceError, + DocStatus, + KbUpdateRequest, + SourceType, + TaskCallbackRequest, + TaskCancelRequest, + UploadRequest, +) +from lazyllm.tools.rag.utils import BaseResponse + + +class _UploadFile: + def __init__(self, filename: str, content: bytes): + self.filename = filename + self._content = content + self._offset = 0 + + async def read(self, size: int = -1): + if self._offset >= len(self._content): + return b'' + if size is None or size < 0: + size = len(self._content) - self._offset + start = self._offset + end = min(len(self._content), start + size) + self._offset = end + return self._content[start:end] + + async def close(self): + return None + + +class _FakeManager: + def __init__(self): + self.run_calls = [] + self.upload_request = None + self.chunk_kwargs = None + self.cancel_response = BaseResponse( + code=200, 
msg='success', data={'task_id': 'task-1', 'cancel_status': True, 'status': 'CANCELED'} + ) + + def run_idempotent(self, endpoint, idempotency_key, payload, handler): + self.run_calls.append({ + 'endpoint': endpoint, + 'idempotency_key': idempotency_key, + 'payload': payload, + }) + return handler() + + def upload(self, request: UploadRequest): + self.upload_request = request + return [ + {'doc_id': item.doc_id or 'generated-doc', 'task_id': f'task-{idx}'} + for idx, item in enumerate(request.items) + ] + + def cancel_task(self, task_id: str): + if isinstance(self.cancel_response, Exception): + raise self.cancel_response + resp = self.cancel_response + if isinstance(resp, BaseResponse): + return resp + return BaseResponse.model_validate(resp) + + def list_chunks(self, **kwargs): + self.chunk_kwargs = kwargs + return {'items': [{'uid': 'chunk-1'}], 'total': 1, 'page': kwargs['page'], 'page_size': kwargs['page_size']} + + +def _decode_response(response): + assert isinstance(response, fastapi.responses.JSONResponse) + return json.loads(response.body.decode()) + + +@pytest.fixture +def server_impl(): + with tempfile.TemporaryDirectory(prefix='lazyllm_doc_server_unit_') as temp_dir: + impl = DocServer._Impl(storage_dir=temp_dir, parser_url='http://parser.test') + impl._manager = _FakeManager() + impl._lazy_init = lambda: None + yield impl + + +def test_build_update_kb_payload_distinguishes_omitted_and_null(): + keep_req = KbUpdateRequest(display_name='Renamed', idempotency_key='kb-update-idem') + clear_req = KbUpdateRequest(display_name='Renamed', owner_id=None, idempotency_key='kb-update-idem') + + keep_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', keep_req) + clear_payload = DocServer._Impl._build_update_kb_payload('kb_local_idem', clear_req) + + assert keep_payload != clear_payload + assert keep_payload['explicit_fields'] == ['display_name', 'idempotency_key'] + assert clear_payload['explicit_fields'] == ['display_name', 'idempotency_key', 'owner_id'] + + +def test_normalize_task_callback_supports_legacy_fields(): + callback = DocServer._Impl._normalize_task_callback({ + 'task_id': 'task-1', + 'task_type': 'DOC_ADD', + 'doc_id': 'doc-1', + 'kb_id': 'kb-1', + 'algo_id': '__default__', + 'task_status': 'SUCCESS', + }) + + assert isinstance(callback, TaskCallbackRequest) + assert callback.event_type == CallbackEventType.FINISH + assert callback.status == DocStatus.SUCCESS + assert callback.payload == { + 'task_type': 'DOC_ADD', + 'doc_id': 'doc-1', + 'kb_id': 'kb-1', + 'algo_id': '__default__', + } + + +def test_run_wraps_doc_service_error(): + response = DocServer._Impl._response(data={'ok': True}) + assert _decode_response(response)['data'] == {'ok': True} + + impl = DocServer._Impl(storage_dir='.', parser_url='http://parser.test') + wrapped = impl._run(lambda: (_ for _ in ()).throw(DocServiceError('E_INVALID_PARAM', 'bad req', {'x': 1}))) + body = _decode_response(wrapped) + + assert body['code'] == 400 + assert body['msg'] == 'bad req' + assert body['data']['biz_code'] == 'E_INVALID_PARAM' + assert body['data']['x'] == 1 + + +def test_parser_url_returns_none_when_remote_endpoint_unavailable(monkeypatch): + server = object.__new__(DocServer) + server._raw_impl = None + server._impl = type('FakeImpl', (), {'_url': 'http://127.0.0.1:19002/generate'})() + + def _raise(*args, **kwargs): + raise requests.ConnectionError('missing endpoint') + + monkeypatch.setattr(requests, 'get', _raise) + + assert server.parser_url is None + + +def 
test_task_cancel_request_requires_task_id(): + with pytest.raises(ValidationError): + TaskCancelRequest.model_validate({}) + + +def test_cancel_task_http_maps_conflict(server_impl): + server_impl._manager.cancel_response = BaseResponse( + code=409, + msg='task cannot be canceled', + data={'task_id': 'task-1', 'cancel_status': False, 'status': 'WORKING'}, + ) + + response = server_impl.cancel_task(TaskCancelRequest(task_id='task-1', idempotency_key='idem')) + body = _decode_response(response) + + assert server_impl._manager.run_calls[0]['endpoint'] == '/v1/tasks/cancel' + assert server_impl._manager.run_calls[0]['idempotency_key'] == 'idem' + assert body['code'] == 409 + assert body['data']['biz_code'] == 'E_STATE_CONFLICT' + assert body['data']['status'] == 'WORKING' + + +def test_list_chunks_forwards_pagination_to_manager(server_impl): + response = server_impl.list_chunks( + kb_id='kb-1', + doc_id='doc-1', + group='block', + algo_id='algo-1', + page=3, + page_size=4, + offset=8, + ) + body = _decode_response(response) + + assert server_impl._manager.chunk_kwargs == { + 'kb_id': 'kb-1', + 'doc_id': 'doc-1', + 'group': 'block', + 'algo_id': 'algo-1', + 'page': 3, + 'page_size': 4, + 'offset': 8, + } + assert body['data']['items'] == [{'uid': 'chunk-1'}] + assert body['data']['total'] == 1 + + +def test_upload_request_uses_idempotency_payload(server_impl): + file_path = os.path.join(server_impl._storage_dir, 'seed.txt') + with open(file_path, 'w', encoding='utf-8') as file: + file.write('seed content') + request = UploadRequest( + kb_id='kb-upload', + algo_id='__default__', + source_type=SourceType.API, + idempotency_key='upload-idem', + items=[AddFileItem(file_path=file_path, doc_id='doc-seed')], + ) + + response = server_impl.upload_request(request) + body = _decode_response(response) + + assert server_impl._manager.run_calls[0]['endpoint'] == '/v1/docs/upload' + assert server_impl._manager.run_calls[0]['payload']['items'][0]['doc_id'] == 'doc-seed' + assert body['data']['items'][0]['doc_id'] == 'doc-seed' + + +def test_upload_http_saves_unique_files_and_only_first_doc_id(server_impl): + files = [ + _UploadFile('dup.txt', b'first content'), + _UploadFile('dup.txt', b'second content'), + ] + + response = asyncio.run(server_impl.upload( + files=files, + kb_id='kb-upload', + algo_id='__default__', + source_type=SourceType.API, + doc_id='doc-first', + idempotency_key='upload-http-idem', + )) + body = _decode_response(response) + upload_request = server_impl._manager.upload_request + + assert body['code'] == 200 + assert len(upload_request.items) == 2 + assert upload_request.items[0].doc_id == 'doc-first' + assert upload_request.items[1].doc_id is None + assert upload_request.items[0].file_path != upload_request.items[1].file_path + assert os.path.exists(upload_request.items[0].file_path) + assert os.path.exists(upload_request.items[1].file_path) diff --git a/tests/basic_tests/RAG/test_document.py b/tests/basic_tests/RAG/test_document.py index 649b74d68..16cb1bdd9 100644 --- a/tests/basic_tests/RAG/test_document.py +++ b/tests/basic_tests/RAG/test_document.py @@ -1,4 +1,5 @@ import lazyllm +import cloudpickle from lazyllm.tools.rag.doc_impl import DocImpl from lazyllm.tools.rag.store.store_base import LAZY_IMAGE_GROUP from lazyllm.tools.rag.transform import SentenceSplitter @@ -6,20 +7,16 @@ from lazyllm.tools.rag.doc_node import DocNode from lazyllm.tools.rag.global_metadata import RAG_DOC_PATH, RAG_DOC_ID from lazyllm.tools.rag import Document, Retriever, TransformArgs, AdaptiveTransform 
-from lazyllm.tools.rag.doc_manager import DocManager -from lazyllm.tools.rag.utils import DocListManager, gen_docid -from lazyllm.launcher import cleanup -from lazyllm import config -from unittest.mock import MagicMock -import unittest -import httpx import os import shutil -import io -import re -import json -import time import tempfile +import time +import unittest +from unittest.mock import MagicMock + +import lazyllm.tools.rag.document as document_module +from lazyllm.launcher import cleanup +from lazyllm.tools.rag.utils import gen_docid class TestDocImpl(unittest.TestCase): @@ -41,6 +38,14 @@ def tearDown(self): self.tmp_file_a.close() self.tmp_file_b.close() + def _build_root_nodes(self, input_files): + root_nodes = [] + for path in input_files: + node = DocNode(group=LAZY_ROOT_NAME, text=os.path.basename(path)) + node._global_metadata = {RAG_DOC_ID: gen_docid(path), RAG_DOC_PATH: path} + root_nodes.append(node) + return {LAZY_ROOT_NAME: root_nodes, LAZY_IMAGE_GROUP: []} + def test_create_node_group_default(self): self.doc_impl._create_builtin_node_group('MyChunk', transform=lambda x: ','.split(x)) self.doc_impl._lazy_init() @@ -85,19 +90,106 @@ def test_add_files(self): new_doc = DocNode(text='new dummy text', group=LAZY_ROOT_NAME) new_doc._global_metadata = {RAG_DOC_ID: gen_docid(self.tmp_file_b.name), RAG_DOC_PATH: self.tmp_file_b.name} self.mock_directory_reader.load_data.return_value = {LAZY_ROOT_NAME: [new_doc], LAZY_IMAGE_GROUP: []} - self.doc_impl._add_doc_to_store([self.tmp_file_b.name]) + self.doc_impl._processor.add_doc([self.tmp_file_b.name]) assert len(self.doc_impl.store.get_nodes(group=LAZY_ROOT_NAME)) == 2 + def test_dataset_path_sync_without_doc_list_manager(self): + self.mock_embed.return_value = [0.1, 0.2, 0.3] + with tempfile.TemporaryDirectory() as temp_dir: + file_path = os.path.join(temp_dir, 'test.txt') + with open(file_path, 'w') as file: + file.write('local dataset path') + + doc_impl = DocImpl(embed=self.mock_embed, dataset_path=temp_dir, enable_path_monitoring=False) + doc_impl._reader = MagicMock() + doc_impl._reader.load_data.side_effect = ( + lambda input_files, metadatas, split_nodes_by_type=True: self._build_root_nodes(input_files) + ) + doc_impl._lazy_init() + + nodes = doc_impl.store.get_nodes(group=LAZY_ROOT_NAME) + assert len(nodes) == 1 + assert nodes[0].global_metadata[RAG_DOC_ID] == gen_docid(file_path) + assert nodes[0].global_metadata[RAG_DOC_PATH] == file_path + assert not hasattr(doc_impl, '_dlm') + + def test_dataset_path_monitor_adds_and_removes_files(self): + self.mock_embed.return_value = [0.1, 0.2, 0.3] + with tempfile.TemporaryDirectory() as temp_dir: + file_a = os.path.join(temp_dir, 'a.txt') + with open(file_a, 'w') as file: + file.write('a') + + doc_impl = DocImpl(embed=self.mock_embed, dataset_path=temp_dir, enable_path_monitoring=True) + doc_impl._reader = MagicMock() + doc_impl._reader.load_data.side_effect = ( + lambda input_files, metadatas, split_nodes_by_type=True: self._build_root_nodes(input_files) + ) + doc_impl._local_monitor_interval = 0.1 + doc_impl._lazy_init() + + def wait_for_doc_ids(expected_ids): + deadline = time.time() + 3 + while time.time() < deadline: + nodes = doc_impl.store.get_nodes(group=LAZY_ROOT_NAME) + current_ids = {node.global_metadata[RAG_DOC_ID] for node in nodes} + if current_ids == expected_ids: + return + time.sleep(0.1) + raise AssertionError(f'Expected doc ids {expected_ids}, got {current_ids}') + + try: + wait_for_doc_ids({gen_docid(file_a)}) + + file_b = os.path.join(temp_dir, 'b.txt') + with 
open(file_b, 'w') as file: + file.write('b') + wait_for_doc_ids({gen_docid(file_a), gen_docid(file_b)}) + + os.remove(file_a) + wait_for_doc_ids({gen_docid(file_b)}) + finally: + doc_impl._local_monitor_continue = False + if doc_impl._local_monitor_thread: + doc_impl._local_monitor_thread.join(timeout=1) + + def test_doc_impl_can_be_pickled_before_lazy_init(self): + doc_impl = DocImpl(embed=self.mock_embed, doc_files=[self.tmp_file_a.name]) + serialized = cloudpickle.dumps(doc_impl) + restored = cloudpickle.loads(serialized) + + assert restored._local_monitor_lock is not None + assert restored._local_monitor_thread is None + class TestDocument(unittest.TestCase): @classmethod def tearDownClass(cls): cleanup() + def setUp(self): + self._temp_dirs = [] + + def tearDown(self): + for temp_dir in self._temp_dirs: + shutil.rmtree(temp_dir, ignore_errors=True) + + def _build_dataset(self, text: str = None) -> str: + temp_dir = tempfile.mkdtemp(prefix='lazyllm_document_') + self._temp_dirs.append(temp_dir) + file_path = os.path.join(temp_dir, 'rag.txt') + with open(file_path, 'w', encoding='utf-8') as file: + file.write(text or '\n'.join( + f'第{i}段:何为天道?人法地,地法天,天法道,道法自然。什么是道?道可道,非常道。' + for i in range(1, 241) + )) + return temp_dir + def test_register_global_and_local(self): Document.create_node_group('Chunk1', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50) Document.create_node_group('Chunk2', transform=TransformArgs( f=SentenceSplitter, kwargs=dict(chunk_size=256, chunk_overlap=25))) - doc1, doc2 = Document('rag_master'), Document('rag_master') + dataset_path = self._build_dataset() + doc1, doc2 = Document(dataset_path), Document(dataset_path) doc2.create_node_group('Chunk2', transform=dict( f=SentenceSplitter, kwargs=dict(chunk_size=128, chunk_overlap=10))) doc2.create_node_group('Chunk3', trans_node=True, @@ -125,8 +217,14 @@ def test_register_global_and_local(self): assert isinstance(r[0], DocNode) def test_create_document(self): - Document('rag_master') - Document('rag_master/') + dataset_path = self._build_dataset() + Document(dataset_path) + Document(dataset_path + os.sep) + + def test_dataset_path_enables_monitoring_by_default_without_manager(self): + doc = Document(self._build_dataset()) + assert doc._manager._enable_path_monitoring is True + assert doc._impl._dataset_path == doc._manager._origin_path def test_register_with_pattern(self): Document.create_node_group('AdaptiveChunk1', transform=[ @@ -136,7 +234,7 @@ def test_register_with_pattern(self): dict(f=SentenceSplitter, pattern=(lambda x: x.endswith('.txt')), kwargs=dict(chunk_size=512, chunk_overlap=50)), TransformArgs(f=SentenceSplitter, pattern=None, kwargs=dict(chunk_size=256, chunk_overlap=25))])) - doc = Document('rag_master') + doc = Document(self._build_dataset()) doc._impl._lazy_init() retriever = Retriever(doc, 'AdaptiveChunk1', similarity='bm25', topk=2) retriever('什么是道') @@ -144,7 +242,7 @@ def test_register_with_pattern(self): retriever('什么是道') def test_create_node_group_with_ref(self): - doc = Document('rag_master') + doc = Document(self._build_dataset()) # Create parent node group doc.create_node_group('parent_chunk', transform=SentenceSplitter, chunk_size=512, chunk_overlap=50) # Create ref node group under parent @@ -169,7 +267,7 @@ def transform_with_ref(text, ref): assert 'doc_test_ref' in ref_nodes[0].text def test_create_node_group_with_invalid_ref(self): - doc = Document('rag_master') + doc = Document(self._build_dataset()) # Create two independent node groups doc.create_node_group('group_a', 
transform=SentenceSplitter, chunk_size=512, chunk_overlap=50) doc.create_node_group('group_b', transform=SentenceSplitter, chunk_size=256, chunk_overlap=25) @@ -190,7 +288,7 @@ def test_find(self): # root --- CoarseChunk < /- chunk21 # \ \- chunk2 < # \- FineChunk \- chunk22 - doc = Document('rag_master') + doc = Document(self._build_dataset()) doc.create_node_group('chunk1', parent=Document.CoarseChunk, transform=dict(f=SentenceSplitter, kwargs=dict(chunk_size=256, chunk_overlap=25))) doc.create_node_group('chunk11', parent='chunk1', @@ -215,22 +313,55 @@ def _test_impl(group, target): _test_impl(group, target) def test_doc_web_module(self): - import time - import requests - doc = Document('rag_master', manager='ui') - doc.create_kb_group(name='test_group') - doc2 = Document('rag_master', manager=doc.manager, name='test_group2') - doc.start() - time.sleep(4) - url = doc._manager._docweb.url - response = requests.get(url) - assert response.status_code == 200 - assert doc2._curr_group == 'test_group2' - assert doc2.manager == doc.manager - doc.stop() + dataset_path = self._build_dataset() + doc = Document(dataset_path, manager=True, create_ui=True) + try: + doc.start() + doc.create_kb_group(name='test_group') + doc2 = Document(dataset_path, manager=doc.manager, name='test_group2') + assert hasattr(doc._manager, '_docweb') + assert doc2._curr_group == 'test_group2' + assert doc2.manager == doc.manager + finally: + doc.stop() + + def test_manager_ui_remains_compatible(self): + dataset_path = self._build_dataset() + doc = Document(dataset_path, manager='ui') + try: + doc.start() + assert hasattr(doc._manager, '_docweb') + assert doc._manager._enable_path_monitoring is False + finally: + doc.stop() + + def test_create_ui_requires_doc_server(self): + with self.assertRaisesRegex(ValueError, 'requires an available DocServer'): + Document(self._build_dataset(), create_ui=True) + + def test_document_processor_manager_constraints(self): + dataset_path = self._build_dataset() + processor = document_module.DocumentProcessor(url='http://127.0.0.1:9966') + + with self.assertRaisesRegex(ValueError, 'store_conf'): + Document(dataset_path, manager=processor) + with self.assertRaisesRegex(ValueError, 'pure local map store'): + Document(dataset_path, manager=processor, store_conf={'type': 'map'}) + + doc = Document( + dataset_path, + manager=processor, + store_conf={'type': 'milvus', 'kwargs': {'uri': 'http://localhost:19530'}}, + ) + try: + assert doc._manager._enable_path_monitoring is False + assert doc._impl._enable_path_monitoring is False + assert doc._impl._dataset_path == doc._manager._origin_path + finally: + doc.stop() def test_get_nodes(self): - doc = Document('rag_master') + doc = Document(self._build_dataset()) doc.create_node_group('chunk1', parent=Document.CoarseChunk, transform=dict(f=SentenceSplitter, kwargs=dict(chunk_size=256, chunk_overlap=25))) doc.activate_groups(groups=['chunk1']) @@ -240,7 +371,7 @@ def test_get_nodes(self): assert n.number == 2 def test_get_window_nodes(self): - doc = Document('rag_master') + doc = Document(self._build_dataset()) doc.create_node_group('chunk1', parent=Document.CoarseChunk, transform=dict(f=SentenceSplitter, kwargs=dict(chunk_size=128, chunk_overlap=12))) doc.activate_groups(groups=['chunk1']) @@ -267,114 +398,3 @@ def test_get_window_nodes(self): assert len(window) == 5 assert window == sorted(window, key=lambda n: n.number) assert all(n.number in [1, 2, 3, 4, 5] for n in window) - -class TmpDir: - def __init__(self): - self.root_dir = 
os.path.expanduser(os.path.join(config['home'], 'rag_for_document_ut')) - self.rag_dir = os.path.join(self.root_dir, 'rag_master') - os.makedirs(self.rag_dir, exist_ok=True) - - def __del__(self): - shutil.rmtree(self.root_dir) - -class TestDocumentServer(unittest.TestCase): - def setUp(self): - self.dir = TmpDir() - self.dlm = DocListManager(path=self.dir.rag_dir, name=None, enable_path_monitoring=False) - - self.doc_impl = DocImpl(embed=MagicMock(), dlm=self.dlm) - self.doc_impl._lazy_init() - - doc_manager = DocManager(self.dlm) - self.server = lazyllm.ServerModule(doc_manager) - - self.server.start() - - url_pattern = r'(http://\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}:\d+)' - self.doc_server_addr = re.findall(url_pattern, self.server._url)[0] - self.time_sleep = 30 - - def test_delete_files_in_store(self): - files = [('files', ('test1.txt', io.BytesIO(b'John\'s house is in Beijing'), 'text/palin')), - ('files', ('test2.txt', io.BytesIO(b'John\'s house is in Shanghai'), 'text/plain'))] - metadatas = [{'comment': 'comment1'}, {'signature': 'signature2'}] - params = dict(override='true', metadatas=json.dumps(metadatas)) - - url = f'{self.doc_server_addr}/upload_files' - response = httpx.post(url, params=params, files=files, timeout=10) - assert response.status_code == 200 and response.json().get('code') == 200, response.json() - ids = response.json().get('data')[0] - lazyllm.LOG.info(f'debug!!! ids -> {ids}') - assert len(ids) == 2 - - time.sleep(self.time_sleep) # waiting for worker thread to update newly uploaded files - - # make sure that ids are written into the store - nodes = self.doc_impl.store.get_nodes(group=LAZY_ROOT_NAME) - doc_ids = [] - doc_file_paths = [] - doc_metadatas = [] - test1_docid = None - test2_docid = None - for node in nodes: - doc_ids.append(node.global_metadata[RAG_DOC_ID]) - doc_file_paths.append(node.global_metadata.get(RAG_DOC_PATH, '')) - doc_metadatas.append(node.global_metadata) - if 'test1' in node.global_metadata.get(RAG_DOC_PATH, ''): - test1_docid = node.global_metadata[RAG_DOC_ID] - elif 'test2' in node.global_metadata.get(RAG_DOC_PATH, ''): - test2_docid = node.global_metadata[RAG_DOC_ID] - lazyllm.LOG.info(f'debug!!! doc_ids -> {doc_ids}\n') - lazyllm.LOG.info(f'debug!!! doc_file_paths -> {doc_file_paths}\n') - lazyllm.LOG.info(f'debug!!! doc_metadatas -> {doc_metadatas}\n') - assert test1_docid and test2_docid - assert set(doc_ids) == set(ids) - - url = f'{self.doc_server_addr}/delete_files' - response = httpx.post(url, json=dict(file_ids=[test1_docid])) - assert response.status_code == 200 and response.json().get('code') == 200 - - time.sleep(self.time_sleep) # waiting for worker thread to delete files - - nodes = self.doc_impl.store.get_nodes(group=LAZY_ROOT_NAME) - assert len(nodes) == 1 - assert nodes[0].global_metadata[RAG_DOC_ID] == test2_docid - cur_meta_dict = nodes[0].global_metadata - - url = f'{self.doc_server_addr}/add_metadata' - response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={'title': 'title2'})) - assert response.status_code == 200 and response.json().get('code') == 200 - time.sleep(self.time_sleep) - lazyllm.LOG.info(f'debug!!! cur_meta_dict -> {cur_meta_dict}\n') - assert cur_meta_dict['title'] == 'title2' - - response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={'title': 'TITLE2'})) - assert response.status_code == 200 and response.json().get('code') == 200 - time.sleep(self.time_sleep) - lazyllm.LOG.info(f'debug!!! 
cur_meta_dict -> {cur_meta_dict}\n') - assert cur_meta_dict['title'] == ['title2', 'TITLE2'] - - url = f'{self.doc_server_addr}/delete_metadata_item' - - response = httpx.post(url, json=dict(doc_ids=[test2_docid], kv_pair={'title': 'TITLE2'})) - assert response.status_code == 200 and response.json().get('code') == 200 - time.sleep(self.time_sleep) - assert cur_meta_dict['title'] == ['title2'] - - url = f'{self.doc_server_addr}/reset_metadata' - response = httpx.post(url, json=dict(doc_ids=[test2_docid], - new_meta={'author': 'author2', 'signature': 'signature_new'})) - assert response.status_code == 200 and response.json().get('code') == 200 - time.sleep(self.time_sleep) - assert cur_meta_dict['signature'] == 'signature_new' and cur_meta_dict['author'] == 'author2' - - url = f'{self.doc_server_addr}/query_metadata' - response = httpx.post(url, json=dict(doc_id=test2_docid)) - - # make sure that only one file is left - response = httpx.get(f'{self.doc_server_addr}/list_files') - assert response.status_code == 200 and len(response.json().get('data')) == 1 - - def tearDown(self): - # Must clean up the server as all uploaded files will be deleted as they are in tmp dir - self.dlm.release() diff --git a/tests/basic_tests/RAG/test_document_store_upsert_failure.py b/tests/basic_tests/RAG/test_document_store_upsert_failure.py new file mode 100644 index 000000000..369ce4aed --- /dev/null +++ b/tests/basic_tests/RAG/test_document_store_upsert_failure.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock + +import pytest + +from lazyllm.tools.rag.doc_node import DocNode +from lazyllm.tools.rag.global_metadata import RAG_DOC_ID, RAG_KB_ID +from lazyllm.tools.rag.store.document_store import _DocumentStore + + +def test_update_nodes_raises_when_store_upsert_returns_false(): + document_store = _DocumentStore(algo_name='test_algo', store={'type': 'map'}) + document_store.activate_group('group1') + document_store.impl.upsert = MagicMock(return_value=False) + + node = DocNode( + uid='node1', + text='text1', + group='group1', + global_metadata={RAG_KB_ID: 'kb1', RAG_DOC_ID: 'doc1'}, + ) + + with pytest.raises(RuntimeError, match='Failed to upsert segments for group group1'): + document_store.update_nodes([node], copy=True) + + +def test_update_doc_meta_raises_when_store_upsert_returns_false(): + document_store = _DocumentStore(algo_name='test_algo', store={'type': 'map'}) + document_store.activate_group('group1') + + node = DocNode( + uid='node1', + text='text1', + group='group1', + global_metadata={RAG_KB_ID: 'kb1', RAG_DOC_ID: 'doc1'}, + ) + document_store.update_nodes([node], copy=True) + document_store.impl.upsert = MagicMock(return_value=False) + + with pytest.raises(RuntimeError, match='Failed to upsert segments for group group1'): + document_store.update_doc_meta(doc_id='doc1', metadata={'foo': 'bar'}, kb_id='kb1') diff --git a/tests/basic_tests/RAG/test_processor_cleanup.py b/tests/basic_tests/RAG/test_processor_cleanup.py new file mode 100644 index 000000000..e42d4fba8 --- /dev/null +++ b/tests/basic_tests/RAG/test_processor_cleanup.py @@ -0,0 +1,83 @@ +from unittest.mock import MagicMock + +import pytest + +from lazyllm.tools.rag.doc_node import DocNode +from lazyllm.tools.rag.global_metadata import RAG_DOC_ID, RAG_KB_ID +from lazyllm.tools.rag.parsing_service import _Processor +from lazyllm.tools.rag.store import LAZY_IMAGE_GROUP, LAZY_ROOT_NAME + + +def _make_root_node(doc_id: str, kb_id: str) -> DocNode: + return DocNode( + uid=f'{doc_id}-root', + text='root', + group=LAZY_ROOT_NAME, + 
global_metadata={RAG_DOC_ID: doc_id, RAG_KB_ID: kb_id}, + ) + + +def test_add_doc_cleans_partial_segments_and_schema_on_failure(): + store = MagicMock() + reader = MagicMock() + schema_extractor = MagicMock() + reader.load_data.return_value = { + LAZY_ROOT_NAME: [_make_root_node('doc1', 'kb1')], + LAZY_IMAGE_GROUP: [], + } + store.update_nodes.side_effect = RuntimeError('upsert failed') + + processor = _Processor( + algo_id='algo1', + store=store, + reader=reader, + node_groups={}, + schema_extractor=schema_extractor, + ) + try: + with pytest.raises(RuntimeError, match='upsert failed'): + processor.add_doc( + input_files=['/tmp/doc1.txt'], + ids=['doc1'], + metadatas=[{}], + kb_id='kb1', + ) + finally: + processor.close() + + store.remove_nodes.assert_called_once_with(doc_ids=['doc1'], kb_id='kb1') + schema_extractor._delete_extract_data.assert_called_once_with( + algo_id='algo1', + kb_id='kb1', + doc_ids=['doc1'], + ) + + +def test_transfer_failure_cleans_target_segments_only(): + store = MagicMock() + reader = MagicMock() + store.get_nodes.return_value = [_make_root_node('source-doc', 'source-kb')] + store.update_nodes.side_effect = RuntimeError('upsert failed') + + processor = _Processor( + algo_id='algo1', + store=store, + reader=reader, + node_groups={}, + ) + try: + with pytest.raises(RuntimeError, match='upsert failed'): + processor.add_doc( + input_files=['/tmp/source.txt'], + ids=['source-doc'], + metadatas=[{}], + kb_id='source-kb', + transfer_mode='cp', + target_kb_id='target-kb', + target_doc_ids=['target-doc'], + ) + finally: + processor.close() + + store.remove_nodes.assert_called_once_with(doc_ids=['target-doc'], kb_id='target-kb') + reader.load_data.assert_not_called() diff --git a/tests/basic_tests/RAG/test_transform.py b/tests/basic_tests/RAG/test_transform.py index 79244aee2..a216d566b 100644 --- a/tests/basic_tests/RAG/test_transform.py +++ b/tests/basic_tests/RAG/test_transform.py @@ -1176,6 +1176,23 @@ def test_get_metadata_size(self): metadata_size = splitter._get_metadata_size(node) assert metadata_size == 0 + def test_get_metadata_size_respects_excluded_metadata_keys(self): + splitter = _TextSplitterBase(chunk_size=200, overlap=10) + node = DocNode( + text='Hello, world! 
This is a test.', + metadata={ + 'file_name': 'test.pdf', + 'title': 'Section 1', + 'lines': [{'content': 'x' * 2000, 'page': 1}], + }, + ) + node.excluded_embed_metadata_keys = ['lines'] + node.excluded_llm_metadata_keys = ['lines'] + + metadata_size = splitter._get_metadata_size(node) + + assert metadata_size < 200 + def test_transform_returns_chunks(self, doc_node): splitter = _TextSplitterBase(chunk_size=20, overlap=10) chunks = splitter([doc_node]) diff --git a/tests/basic_tests/Tools/test_sql_manager.py b/tests/basic_tests/Tools/test_sql_manager.py new file mode 100644 index 000000000..fbd4b4680 --- /dev/null +++ b/tests/basic_tests/Tools/test_sql_manager.py @@ -0,0 +1,64 @@ +from unittest.mock import MagicMock, patch + +import sqlalchemy +from sqlalchemy.exc import OperationalError + +from lazyllm.tools.sql.sql_manager import SqlManager + + +def _make_unknown_database_error(): + class _UnknownDatabaseError(Exception): + args = (1049, "Unknown database 'lazyllm_doc_task'") + + return OperationalError('SELECT 1', {}, _UnknownDatabaseError()) + + +def _make_conn_cm(conn): + cm = MagicMock() + cm.__enter__.return_value = conn + cm.__exit__.return_value = False + return cm + + +class TestSqlManager(object): + + @patch('lazyllm.tools.sql.sql_manager.sqlalchemy.create_engine') + def test_tidb_create_database_when_missing(self, mock_create_engine): + probe_engine = MagicMock() + probe_engine.connect.side_effect = _make_unknown_database_error() + + admin_conn = MagicMock() + admin_engine = MagicMock() + admin_engine.connect.return_value = _make_conn_cm(admin_conn) + + final_engine = MagicMock() + mock_create_engine.side_effect = [probe_engine, admin_engine, final_engine] + + sql_manager = SqlManager('tidb', 'root', 'pwd', '127.0.0.1', 4000, 'lazyllm_doc_task') + + assert sql_manager.engine is final_engine + assert mock_create_engine.call_args_list[1].args[0] == 'mysql+pymysql://root:pwd@127.0.0.1:4000/' + assert str(admin_conn.execute.call_args.args[0]) == 'CREATE DATABASE IF NOT EXISTS `lazyllm_doc_task`' + + @patch('lazyllm.tools.sql.sql_manager.sqlalchemy.create_engine') + def test_tidb_skip_create_database_when_exists(self, mock_create_engine): + probe_engine = MagicMock() + probe_engine.connect.return_value = _make_conn_cm(MagicMock()) + + final_engine = MagicMock() + mock_create_engine.side_effect = [probe_engine, final_engine] + + sql_manager = SqlManager('tidb', 'root', 'pwd', '127.0.0.1', 4000, 'lazyllm_doc_task') + + assert sql_manager.engine is final_engine + assert mock_create_engine.call_count == 2 + + def test_tidb_primary_key_string_uses_varchar(self): + sql_manager = SqlManager('tidb', 'root', 'pwd', '127.0.0.1', 4000, 'lazyllm_doc_task') + + primary_key_type = sql_manager._sql_type_for('string', is_primary_key=True) + normal_string_type = sql_manager._sql_type_for('string') + + assert isinstance(primary_key_type, sqlalchemy.String) + assert primary_key_type.length == 255 + assert normal_string_type is sqlalchemy.Text
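Not part of the patch itself, but as a minimal usage sketch of the sql_manager.py change above: with PostgreSQL now handled by _ensure_database_exists(), first access to the engine property probes the target database and creates it if the server reports it missing. The connection parameters below are hypothetical; only the positional constructor signature (db_type, user, password, host, port, db_name) is taken from the tests in this diff.

    from lazyllm.tools.sql.sql_manager import SqlManager

    # Hypothetical PostgreSQL credentials for illustration only.
    sql_manager = SqlManager('postgresql', 'postgres', 'pwd', '127.0.0.1', 5432, 'lazyllm_doc_task')

    # Accessing .engine runs the probe: a "database does not exist" error (pgcode 3D000)
    # triggers CREATE DATABASE through an AUTOCOMMIT admin engine, after which the pooled
    # engine is built with _default_engine_kwargs().
    engine = sql_manager.engine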