diff --git a/algorithm/chat/__init__.py b/algorithm/chat/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/algorithm/chat/app/api/__init__.py b/algorithm/chat/app/api/__init__.py new file mode 100644 index 0000000..bab6b6b --- /dev/null +++ b/algorithm/chat/app/api/__init__.py @@ -0,0 +1,6 @@ +# 路由定义 (v1/v2) +from __future__ import annotations + +from chat.app.core.chat_server import create_app + +__all__ = ['create_app'] diff --git a/algorithm/chat/app/api/chat_routes.py b/algorithm/chat/app/api/chat_routes.py new file mode 100644 index 0000000..9044ed6 --- /dev/null +++ b/algorithm/chat/app/api/chat_routes.py @@ -0,0 +1,39 @@ +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, Body, Request +from chat.config import DEFAULT_CHAT_DATASET +from chat.app.core.chat_service import handle_chat + +router = APIRouter() + + +@router.post('/api/chat', summary='与知识库对话') +@router.post('/api/chat/stream', summary='与知识库对话') +async def chat( + query: str = Body(..., description='用户问题'), # noqa: B008 + history: Optional[List[Dict[str, Any]]] = Body(default=None, description='历史对话(每项可含 role、content)'), # noqa: B008 + session_id: str = Body('session_id', description='会话 ID'), # noqa: B008 + filters: Optional[Dict[str, Any]] = Body(None, description='检索过滤条件'), # noqa: B008 + files: Optional[List[str]] = Body(None, description='上传临时文件'), # noqa: B008 + debug: Optional[bool] = Body(False, description='是否开启debug模式'), # noqa: B008 + reasoning: Optional[bool] = Body(False, description='是否开启推理'), # noqa: B008 + databases: Optional[List[Dict]] = Body([], description='关联数据库'), # noqa: B008 + dataset: Optional[str] = Body(DEFAULT_CHAT_DATASET, description='数据库名称'), # noqa: B008 + priority: Optional[int] = Body(None, description='请求优先级,用于vllm调度。数值越大优先级越高'), # noqa: B008 + *, + request: Request, +): + is_stream = request.url.path.endswith('/stream') + return await handle_chat( + query=query, + history=history, + session_id=session_id, + filters=filters, + files=files, + debug=debug, + reasoning=reasoning, + databases=databases, + dataset=dataset, + priority=priority, + is_stream=is_stream, + ) diff --git a/algorithm/chat/app/api/health_routes.py b/algorithm/chat/app/api/health_routes.py new file mode 100644 index 0000000..531d001 --- /dev/null +++ b/algorithm/chat/app/api/health_routes.py @@ -0,0 +1,22 @@ +import os + +import httpx +from fastapi import APIRouter + +router = APIRouter() + + +@router.get('/health', summary='Health check') +@router.get('/api/health', summary='Health check (API path)') +async def health(): + doc_url = os.getenv('LAZYRAG_DOCUMENT_SERVER_URL', 'http://localhost:8000') + check_url = doc_url.rstrip('/') + '/' + status = {'document_server_url': doc_url, 'document_server_reachable': None} + try: + async with httpx.AsyncClient(timeout=3.0) as client: + await client.get(check_url) + status['document_server_reachable'] = True + except Exception as e: + status['document_server_reachable'] = False + status['document_server_error'] = str(e) + return status diff --git a/algorithm/chat/app/chat.py b/algorithm/chat/app/chat.py new file mode 100644 index 0000000..95f92fe --- /dev/null +++ b/algorithm/chat/app/chat.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +from chat.app.api import create_app + +app = create_app() + +if __name__ == '__main__': + import argparse + import uvicorn + + parser = argparse.ArgumentParser() + parser.add_argument('--host', type=str, default='0.0.0.0', help='listen host') + parser.add_argument('--port', type=int, default=8046, help='listen port') + args = parser.parse_args() + + uvicorn.run(app, host=args.host, port=args.port) diff --git a/algorithm/chat/app/core/chat_server.py b/algorithm/chat/app/core/chat_server.py new file mode 100644 index 0000000..db77280 --- /dev/null +++ b/algorithm/chat/app/core/chat_server.py @@ -0,0 +1,77 @@ +from __future__ import annotations +from typing import Any, Dict, Optional +from fastapi import FastAPI +from lazyllm import LOG, once_wrapper + +from chat.config import URL_MAP, SENSITIVE_WORDS_PATH, DEFAULT_CHAT_DATASET +from chat.pipelines.agentic import agentic_rag +from chat.pipelines.naive import get_ppl_naive +from chat.components.process.sensitive_filter import SensitiveFilter + + +def create_app() -> FastAPI: + """FastAPI 应用初始化与路由挂载;pipeline 在模块导入时由 ChatServer 注册。""" + app = FastAPI( + title='LazyLLM Chat API', + description='基于知识库的对话 API 服务', + version='1.0.0', + ) + from chat.app.api import chat_routes, health_routes + + app.include_router(health_routes.router) + app.include_router(chat_routes.router) + return app + + +class ChatServer: + def __init__(self): + self.startup_validated = False + self.startup_validation_error: Optional[str] = None + self._on_server_start() + + @once_wrapper + def _on_server_start(self): + try: + self.query_ppl: Dict[str, Any] = {} + self.query_ppl_stream: Dict[str, Any] = {} + self.query_ppl_reasoning = agentic_rag + self.sensitive_filter = SensitiveFilter(SENSITIVE_WORDS_PATH) + + if self.sensitive_filter.loaded: + LOG.info( + f'[ChatServer] [SENSITIVE_FILTER] Successfully loaded ' + f'{self.sensitive_filter.keyword_count} sensitive keywords' + ) + else: + LOG.warning('[ChatServer] [SENSITIVE_FILTER] Failed to load, filter disabled') + + if DEFAULT_CHAT_DATASET in URL_MAP: + self.get_query_pipeline(DEFAULT_CHAT_DATASET) + self.get_query_pipeline(DEFAULT_CHAT_DATASET, stream=True) + self.startup_validated = True + else: + self.startup_validation_error = ( + f'default dataset `{DEFAULT_CHAT_DATASET}` not found in URL_MAP' + ) + raise KeyError(self.startup_validation_error) + + LOG.info('[ChatServer] [SERVER_START]') + except Exception as exc: + self.startup_validated = False + self.startup_validation_error = str(exc) + LOG.exception('[ChatServer] [SERVER_START_ERROR]') + raise exc + + def has_dataset(self, dataset: str) -> bool: + return dataset in URL_MAP + + def get_query_pipeline(self, dataset: str, *, stream: bool = False) -> Any: + if dataset not in URL_MAP: + raise KeyError(f'dataset `{dataset}` not found in URL_MAP') + pipeline_map = self.query_ppl_stream if stream else self.query_ppl + if dataset not in pipeline_map: + pipeline_map[dataset] = get_ppl_naive(url=URL_MAP[dataset], stream=stream) + return pipeline_map[dataset] + + +chat_server = ChatServer() diff --git a/algorithm/chat/app/core/chat_service.py b/algorithm/chat/app/core/chat_service.py new file mode 100644 index 0000000..805b670 --- /dev/null +++ b/algorithm/chat/app/core/chat_service.py @@ -0,0 +1,217 @@ +from __future__ import annotations +import asyncio +import json +import time +import uuid +from typing import Any, Dict, List, Optional, Union +import lazyllm +from lazyllm import LOG +from fastapi.responses import StreamingResponse +from chat.config import (URL_MAP, RAG_MODE, MULTIMODAL_MODE, MAX_CONCURRENCY, + LAZYRAG_LLM_PRIORITY, SENSITIVE_FILTER_RESPONSE_TEXT) +from chat.utils.helpers import validate_and_resolve_files +from chat.app.core.chat_server import chat_server + + +rag_sem = asyncio.Semaphore(MAX_CONCURRENCY) + + +def _sse_line(payload: Dict[str, Any]) -> str: + return json.dumps(payload, ensure_ascii=False, default=str) + '\n\n' + + +def _resp(code: int, msg: str, data: Any, cost: float) -> Dict[str, Any]: + return {'code': code, 'msg': msg, 'data': data, 'cost': cost} + + +def check_sensitive_content( + query: str, session_id: str, start_time: float +) -> Optional[Dict[str, Any]]: + if not chat_server.sensitive_filter.loaded: + return None + has_sensitive, sensitive_word = chat_server.sensitive_filter.check(query) + if has_sensitive: + cost = round(time.time() - start_time, 3) + LOG.warning( + f'[ChatServer] [SENSITIVE_FILTER_BLOCKED] [query={query[:50]}...] ' + f'[sensitive_word={sensitive_word}] [session_id={session_id}]' + ) + return _resp( + 200, + 'success', + { + 'think': None, + 'text': SENSITIVE_FILTER_RESPONSE_TEXT, + 'sources': [], + }, + cost, + ) + return None + + +def build_query_params(query: str, history: Optional[List[Dict[str, Any]]], + filters: Optional[Dict[str, Any]], other_files: List[str], + databases: Optional[List[Dict[str, Any]]], debug: bool, + image_files: List[str], priority: Optional[int]) -> Dict[str, Any]: + hist = [ + { + 'role': str(h.get('role', 'assistant')), + 'content': str(h.get('content', '')), + } + for h in (history or []) + if isinstance(h, dict) + ] + return { + 'query': query, 'history': hist, 'filters': filters if RAG_MODE and filters else {}, + 'files': other_files, 'image_files': image_files if MULTIMODAL_MODE and image_files else [], + 'debug': debug, 'databases': databases if RAG_MODE and databases else [], 'priority': priority + } + + +def log_chat_request(query: str, session_id: str, filters: Optional[Dict[str, Any]], + other_files: List[str], databases: Optional[List[Dict[str, Any]]], + image_files: List[str], cost: float, + response: Any = None, log_type: str = 'KB_CHAT') -> None: + databases_str = json.dumps(databases, ensure_ascii=False) if databases else [] + response_str = response if response is not None else None + LOG.info( + f'[ChatServer] [{log_type}] [query={query}] [session_id={session_id}] ' + f'[filters={filters}] [files={other_files}] [image_files={image_files}] ' + f'[databases={databases_str}] [cost={cost}] [response={response_str}]' + ) + + +async def handle_chat(query: str, history: Optional[List[Dict[str, Any]]], + session_id: str, filters: Optional[Dict[str, Any]], + files: Optional[List[str]], debug: Optional[bool], reasoning: Optional[bool], + databases: Optional[List[Dict[str, Any]]], dataset: Optional[str], + priority: Optional[int], is_stream: bool) -> Union[Dict[str, Any], StreamingResponse]: + result = None + priority = LAZYRAG_LLM_PRIORITY if priority is None else priority + + if not chat_server.has_dataset(dataset): + return _resp(400, f'dataset {dataset} not found', None, 0.0) + + start_time = time.time() + sensitive_check_result = check_sensitive_content(query, session_id, start_time) + sid = f'{session_id}_{time.time()}_{uuid.uuid4().hex}' + log_tag = 'KB_CHAT_STREAM' if is_stream else 'KB_CHAT' + LOG.info(f'[ChatServer] [{log_tag}] [query={query}] [sid={sid}]') + + if not is_stream: + if sensitive_check_result: + return sensitive_check_result + + other_files, image_files = validate_and_resolve_files(files) + query_params = build_query_params( + query, history, filters, other_files, image_files, + debug or False, databases, priority + ) + + try: + async with rag_sem: + lazyllm.globals._init_sid(sid=sid) + lazyllm.locals._init_sid(sid=sid) + result = await _run_sync_ppl( + bool(reasoning), dataset, query_params, query, filters, priority + ) + cost = round(time.time() - start_time, 3) + return _resp(200, 'success', result, cost) + except Exception as exc: + LOG.exception(exc) + cost = round(time.time() - start_time, 3) + return _resp(500, f'chat service failed: {exc}', None, cost) + finally: + cost = round(time.time() - start_time, 3) + log_chat_request( + query, sid, filters, other_files, image_files, databases, cost, result + ) + else: + if sensitive_check_result: + + async def error_stream(): + yield _sse_line(sensitive_check_result) + yield _sse_line(_resp(200, 'success', {'status': 'FINISHED'}, 0.0)) + + return StreamingResponse(error_stream(), media_type='text/event-stream') + + first_frame_logged = False + other_files, image_files = validate_and_resolve_files(files) + collected_chunks: List[str] = [] + + query_params = build_query_params( + query, history, filters, other_files, image_files, False, databases, priority + ) + + stream_call = ( + (chat_server.query_ppl_reasoning, query_params, None, True) + if reasoning + else (chat_server.get_query_pipeline(dataset, stream=True), query_params) + ) + + async def event_stream(ppl, *args) -> Any: + nonlocal first_frame_logged + try: + async with rag_sem: + lazyllm.globals._init_sid(sid=sid) + lazyllm.locals._init_sid(sid=sid) + async_result = await asyncio.to_thread(ppl, *args) + async for chunk in async_result: + now = time.time() + if not first_frame_logged: + first_cost = round(now - start_time, 3) + LOG.info( + f'[ChatServer] [KB_CHAT_STREAM_FIRST_FRAME] ' + f'[query={query}] [session_id={session_id}] ' + f'[cost={first_cost}]' + ) + first_frame_logged = True + + chunk_str = ( + chunk + if isinstance(chunk, str) + else json.dumps(chunk, ensure_ascii=False) + ) + collected_chunks.append(chunk_str) + cost = round(now - start_time, 3) + yield _sse_line(_resp(200, 'success', chunk, cost)) + + except Exception as exc: + LOG.exception(exc) + collected_chunks.append(f'[EXCEPTION]: {str(exc)}') + final_resp = _resp( + 500, f'chat service failed: {exc}', {'status': 'FAILED'}, 0.0 + ) + else: + final_resp = _resp(200, 'success', {'status': 'FINISHED'}, 0.0) + + cost = round(time.time() - start_time, 3) + final_resp['cost'] = cost + yield _sse_line(final_resp) + + log_chat_request(query, sid, filters, other_files, image_files, databases, + cost, '\n'.join(collected_chunks), 'KB_CHAT_STREAM_FINISH') + + return StreamingResponse( + event_stream(*stream_call), media_type='text/event-stream' + ) + + +async def _run_sync_ppl(reasoning: bool, dataset: str, query_params: Dict[str, Any], + query: str, filters: Optional[Dict[str, Any]], priority: Any) -> Any: + if reasoning: + return await asyncio.to_thread( + chat_server.query_ppl_reasoning, + {'query': query}, + { + 'kb_search': { + 'filters': filters, + 'files': [], + 'stream': False, + 'priority': priority, + 'document_url': URL_MAP[dataset], + } + }, + False, + ) + return await asyncio.to_thread(chat_server.get_query_pipeline(dataset), query_params) diff --git a/algorithm/chat/chat.py b/algorithm/chat/chat.py deleted file mode 100644 index d9eb077..0000000 --- a/algorithm/chat/chat.py +++ /dev/null @@ -1,413 +0,0 @@ -from __future__ import annotations -import asyncio -import json -import os -import sys -import time -import uuid -from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple, TypeVar -from dotenv import load_dotenv -from fastapi import Body, FastAPI, HTTPException, Request -from fastapi.responses import StreamingResponse -from pydantic import BaseModel, Field -import lazyllm -from lazyllm import LOG, once_wrapper - -BASE_DIR = Path(__file__).resolve().parent.parent -sys.path.insert(0, str(BASE_DIR)) - -from chat.chat_pipelines.agentic import agentic_rag # noqa: E402 -from chat.chat_pipelines.naive import get_rag_ppl # noqa: E402 -from chat.modules.engineering.sensitive_filter import SensitiveFilter # noqa: E402 - -load_dotenv() -# --------------------------------------------------------------------------- -# 配置项与依赖注入 -# --------------------------------------------------------------------------- -MOUNT_BASE_DIR: str = os.getenv('LAZYLLM_MOUNT_DIR', '/data') -SENSITIVE_WORDS_PATH: str = os.getenv('SENSITIVE_WORDS_PATH', 'data/sensitive_words.txt') -_LAZYRAG_LLM_PRIORITY_ENV = os.getenv('LAZYRAG_LLM_PRIORITY') -LAZYRAG_LLM_PRIORITY = ( - int(_LAZYRAG_LLM_PRIORITY_ENV) - if _LAZYRAG_LLM_PRIORITY_ENV is not None and _LAZYRAG_LLM_PRIORITY_ENV.isdigit() - else 0 -) - -# 配置不同模式的开关 -RAG_MODE = os.getenv('RAG_MODE', 'True').lower() == 'true' -MULTIMODAL_MODE = os.getenv('MULTIMODAL_MODE', 'True').lower() == 'true' - -# --------------------------------------------------------------------------- -# 常量定义 -# --------------------------------------------------------------------------- -# 敏感词被拦截时的响应文本 -SENSITIVE_FILTER_RESPONSE_TEXT = '对不起,我还没有学会回答这个问题。如果你有其他问题,我非常乐意为你提供帮助。' -# 支持的图片扩展名 -IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg') -DEFAULT_ALGO_SERVICE_URL = os.getenv('LAZYRAG_ALGO_SERVICE_URL', 'http://lazyllm-algo:8000').rstrip('/') -DEFAULT_ALGO_DATASET_NAME = os.getenv('LAZYRAG_ALGO_DATASET_NAME', 'general_algo') -DEFAULT_CHAT_DATASET = os.getenv('LAZYRAG_DEFAULT_CHAT_DATASET', 'algo') - -# --------------------------------------------------------------------------- -# 临时方案 待jiahao重构doc服务 -# --------------------------------------------------------------------------- - -url_map: Dict[str, str] = { - 'algo': f'{DEFAULT_ALGO_SERVICE_URL},{DEFAULT_ALGO_DATASET_NAME}', - 'default': f'{DEFAULT_ALGO_SERVICE_URL},{DEFAULT_ALGO_DATASET_NAME}', - 'general_algo': f'{DEFAULT_ALGO_SERVICE_URL},{DEFAULT_ALGO_DATASET_NAME}', - 'research_center': 'http://10.119.16.66:9003,research_center_0131_a', - 'quantum': 'http://10.119.16.66:9002,quantum_0131_a', - 'tyy': 'http://10.119.16.66:9007,tyy_0302', - 'cf': 'http://10.119.16.66:9005,cf_0304', - '3m': 'http://10.119.16.66:9006,threem_0303', - 'crag': 'http://10.119.16.66:9001,crag_0130_a', -} - - -def build_query_pipeline(url: str, *, stream: bool = False) -> Any: - return get_rag_ppl(url=url, stream=stream) - - -# --------------------------------------------------------------------------- -# Pydantic 响应模型 -# --------------------------------------------------------------------------- -M = TypeVar('M') - - -class BaseResponse(BaseModel): - code: int = Field(200, description='API status code') - msg: str = Field('success', description='API status message') - data: Optional[M] = Field(None, description='API data') - - class Config: - schema_extra = {'example': {'code': 200, 'msg': 'success', 'data': None}} - - -class History(BaseModel): - role: str = Field('assistant', description='消息来自哪个角色,user / assistant') - content: str = Field('', description='消息内容') - - -class ChatResponse(BaseResponse): - cost: float = Field(0.0, description='API cost time (seconds)') - - class Config: - schema_extra = { - 'example': { - 'code': 200, - 'msg': 'success', - 'data': None, - 'cost': 0.1, - } - } - - -# --------------------------------------------------------------------------- -# FastAPI 实例 -# --------------------------------------------------------------------------- -app = FastAPI( - title='LazyLLM Chat API', - description='基于知识库的对话 API 服务', - version='1.0.0', -) - - -# --------------------------------------------------------------------------- -# Server 封装 -# --------------------------------------------------------------------------- -class ChatServer: - def __init__(self): - self.startup_validated = False - self.startup_validation_error: Optional[str] = None - self._on_server_start() - - @once_wrapper - def _on_server_start(self): - try: - self.query_ppl: Dict[str, Any] = {} - self.query_ppl_stream: Dict[str, Any] = {} - self.query_ppl_reasoning = agentic_rag - self.sensitive_filter = SensitiveFilter(SENSITIVE_WORDS_PATH) - if self.sensitive_filter.loaded: - LOG.info( - f'[ChatServer] [SENSITIVE_FILTER] Successfully loaded ' - f'{self.sensitive_filter.keyword_count} sensitive keywords' - ) - else: - LOG.warning('[ChatServer] [SENSITIVE_FILTER] Failed to load, filter disabled') - - if DEFAULT_CHAT_DATASET in url_map: - self.get_query_pipeline(DEFAULT_CHAT_DATASET) - self.get_query_pipeline(DEFAULT_CHAT_DATASET, stream=True) - self.startup_validated = True - else: - self.startup_validation_error = ( - f'default dataset `{DEFAULT_CHAT_DATASET}` not found in url_map' - ) - raise KeyError(self.startup_validation_error) - - LOG.info('[ChatServer] [SERVER_START]') - except Exception as exc: - self.startup_validated = False - self.startup_validation_error = str(exc) - LOG.exception('[ChatServer] [SERVER_START_ERROR]') - raise exc - - def has_dataset(self, dataset: str) -> bool: - return dataset in url_map - - def get_query_pipeline(self, dataset: str, *, stream: bool = False) -> Any: - if dataset not in url_map: - raise KeyError(dataset) - pipeline_map = self.query_ppl_stream if stream else self.query_ppl - if dataset not in pipeline_map: - pipeline_map[dataset] = build_query_pipeline(url_map[dataset], stream=stream) - return pipeline_map[dataset] - - -chat_server = ChatServer() -MAX_CONCURRENCY = int(os.getenv('MAX_CONCURRENCY', 10)) -rag_sem = asyncio.Semaphore(MAX_CONCURRENCY) - - -# --------------------------------------------------------------------------- -# 工具函数 -# --------------------------------------------------------------------------- -def _validate_and_resolve_files(files: Optional[List[str]]) -> tuple[List[str], List[str]]: - if not files: - return [], [] - - resolved: List[str] = [] - for f in files: - real_path = f if os.path.isabs(f) else os.path.join(MOUNT_BASE_DIR, f) - if not (os.path.isfile(real_path) and os.access(real_path, os.R_OK)): - raise HTTPException(status_code=400, detail=f'File {real_path} is not accessible') - resolved.append(real_path) - - image_files = [p for p in resolved if p.lower().endswith(IMAGE_EXTENSIONS)] - other_files = [p for p in resolved if p not in image_files] - return other_files, image_files - - -def _check_sensitive_content(query: str, session_id: str, start_time: float) -> Optional[Tuple[float, str]]: - if not chat_server.sensitive_filter.loaded: - return None - has_sensitive, sensitive_word = chat_server.sensitive_filter.check(query) - if has_sensitive: - cost = round(time.time() - start_time, 3) - LOG.warning( - f'[ChatServer] [SENSITIVE_FILTER_BLOCKED] [query={query[:50]}...] ' - f'[sensitive_word={sensitive_word}] [session_id={session_id}]' - ) - return ChatResponse( - code=200, - msg='success', - data={ - 'think': None, - 'text': SENSITIVE_FILTER_RESPONSE_TEXT, - 'sources': [] - }, - cost=cost - ) - return None - - -def _build_query_params( - query: str, - history: List[History], - filters: Optional[Dict[str, Any]], - other_files: List[str], - image_files: List[str], - debug: bool, - databases: Optional[List[Dict[str, Any]]], - priority: Optional[int] -) -> Dict[str, Any]: - return { - 'query': query, - 'history': [h.model_dump() for h in history], - 'filters': filters if RAG_MODE and filters else {}, - 'files': other_files, - 'image_files': image_files if MULTIMODAL_MODE and image_files else [], - 'debug': debug, - 'databases': databases if RAG_MODE and databases else [], - 'priority': priority - } - - -def _log_chat_request( - query: str, - session_id: str, - filters: Optional[Dict[str, Any]], - other_files: List[str], - image_files: List[str], - databases: Optional[List[Dict[str, Any]]], - cost: float, - response: Any = None, - log_type: str = 'KB_CHAT' -) -> None: - databases_str = json.dumps(databases, ensure_ascii=False) if databases else [] - response_str = response if response is not None else None - LOG.info( - f'[ChatServer] [{log_type}] [query={query}] [session_id={session_id}] ' - f'[filters={filters}] [files={other_files}] [image_files={image_files}] ' - f'[databases={databases_str}] [cost={cost}] [response={response_str}]' - ) - - -@app.get('/health', summary='Health check') -@app.get('/api/health', summary='Health check (API path)') -async def health(): - doc_url = os.getenv('LAZYRAG_DOCUMENT_SERVER_URL', 'http://localhost:8000') - status = { - 'document_server_url': doc_url, - 'document_server_reachable': None, - 'default_dataset': DEFAULT_CHAT_DATASET, - 'chat_startup_validated': chat_server.startup_validated, - } - if chat_server.startup_validation_error: - status['chat_startup_error'] = chat_server.startup_validation_error - try: - import urllib.request - req = urllib.request.Request(doc_url.rstrip('/') + '/', method='GET') - urllib.request.urlopen(req, timeout=3) - status['document_server_reachable'] = True - except Exception as e: - status['document_server_reachable'] = False - status['document_server_error'] = str(e) - return status - - -@app.post('/api/chat', summary='与知识库对话') -@app.post('/api/chat/stream', summary='与知识库对话') -async def chat( - query: str = Body(..., description='用户问题'), # noqa: B008 - history: List[History] = Body(default=None, description='历史对话,可为 list 或省略(代理可能传为 {})'), # noqa: B008 - session_id: str = Body('session_id', description='会话 ID'), # noqa: B008 - filters: Optional[Dict[str, Any]] = Body(None, description='检索过滤条件'), # noqa: B008 - files: Optional[List[str]] = Body(None, description='上传临时文件'), # noqa: B008 - debug: Optional[bool] = Body(False, description='是否开启debug模式'), # noqa: B008 - reasoning: Optional[bool] = Body(False, description='是否开启推理'), # noqa: B008 - databases: Optional[List[Dict]] = Body([], description='关联数据库'), # noqa: B008 - dataset: Optional[str] = Body(DEFAULT_CHAT_DATASET, description='数据库名称'), # noqa: B008 临时方案,待jiahao重构doc服务 - priority: Optional[int] = Body( # noqa: B008 - None, - description='请求优先级,用于vllm调度。数值越大优先级越高,默认从环境变量LAZYRAG_LLM_PRIORITY读取', - ), - *, - request: Request, -) -> ChatResponse: - cost = 0.0 - result = None - is_stream = request.url.path.endswith('/stream') - priority = int(os.getenv('LAZYRAG_LLM_PRIORITY', '0')) if priority is None else priority - if not chat_server.has_dataset(dataset): - return ChatResponse(code=400, msg=f'dataset {dataset} not found', cost=0.0) - start_time = time.time() - sensitive_check_result = _check_sensitive_content(query, session_id, start_time) - sid = f'{session_id}_{time.time()}_{uuid.uuid4().hex}' - log_tag = 'KB_CHAT_STREAM' if is_stream else 'KB_CHAT' - LOG.info(f'[ChatServer] [{log_tag}] [query={query}] [sid={sid}]') - if not is_stream: - if sensitive_check_result: - return sensitive_check_result - other_files, image_files = _validate_and_resolve_files(files) - query_params = _build_query_params( - query, history, filters, other_files, image_files, debug, databases, priority - ) - try: - async with rag_sem: - lazyllm.globals._init_sid(sid=sid) - lazyllm.locals._init_sid(sid=sid) - if reasoning: - global_params = {'query': query} - tool_params = {'kb_search': {'filters': filters, 'files': [], - 'stream': False, 'priority': priority, - 'document_url': url_map[dataset]}} - result = await asyncio.to_thread(chat_server.query_ppl_reasoning, global_params, tool_params, False) - else: - result = await asyncio.to_thread(chat_server.get_query_pipeline(dataset), query_params) - - cost = round(time.time() - start_time, 3) - return ChatResponse(code=200, msg='success', data=result, cost=cost) - except Exception as exc: - LOG.exception(exc) - cost = round(time.time() - start_time, 3) - return ChatResponse(code=500, msg=f'chat service failed: {exc}', cost=cost) - finally: - cost = round(time.time() - start_time, 3) - _log_chat_request(query, sid, filters, other_files, image_files, databases, cost, result) - else: - if sensitive_check_result: - async def error_stream(): - yield sensitive_check_result.model_dump_json() + '\n\n' - finish_resp = ChatResponse(code=200, msg='success', data={'status': 'FINISHED'}) - yield finish_resp.model_dump_json() + '\n\n' - return StreamingResponse(error_stream(), media_type='text/event-stream') - - first_frame_logged = False # 标记是否已记录首帧耗时 - other_files, image_files = _validate_and_resolve_files(files) - collected_chunks: List[str] = [] - - query_params = _build_query_params( - query, history, filters, other_files, image_files, False, databases, priority - ) - - async def event_stream(ppl, *args, **kwargs) -> Any: - nonlocal first_frame_logged - try: - async with rag_sem: - lazyllm.globals._init_sid(sid=sid) - lazyllm.locals._init_sid(sid=sid) - async_result = await asyncio.to_thread(ppl, *args, **kwargs) - async for chunk in async_result: - now = time.time() - # ------ 记录首帧耗时 ---------------------------------- - if not first_frame_logged: - first_cost = round(now - start_time, 3) - LOG.info( - f'[ChatServer] [KB_CHAT_STREAM_FIRST_FRAME] ' - f'[query={query}] [session_id={session_id}] ' - f'[cost={first_cost}]' - ) - first_frame_logged = True - # --------------------------------------------------------- - - chunk_str = chunk if isinstance(chunk, str) else json.dumps(chunk, ensure_ascii=False) - collected_chunks.append(chunk_str) - cost = round(now - start_time, 3) - yield ChatResponse(code=200, msg='success', data=chunk, cost=cost).model_dump_json() + '\n\n' - - except Exception as exc: - LOG.exception(exc) - collected_chunks.append(f'[EXCEPTION]: {str(exc)}') - final_resp = ChatResponse(code=500, msg=f'chat service failed: {exc}', data={'status': 'FAILED'}) - else: - final_resp = ChatResponse(code=200, msg='success', data={'status': 'FINISHED'}) - - cost = round(time.time() - start_time, 3) - final_resp.cost = cost - yield final_resp.model_dump_json() + '\n\n' - - response_text = '\n'.join(collected_chunks) - _log_chat_request(query, sid, filters, other_files, image_files, databases, - cost, response_text, 'KB_CHAT_STREAM_FINISH') - - if reasoning: - return StreamingResponse(event_stream(chat_server.query_ppl_reasoning, query_params, None, True), - media_type='text/event-stream') - return StreamingResponse(event_stream(chat_server.get_query_pipeline(dataset, stream=True), query_params), - media_type='text/event-stream') - - -if __name__ == '__main__': - import argparse - import uvicorn - - parser = argparse.ArgumentParser() - parser.add_argument('--host', type=str, default='0.0.0.0', help='listen host') - parser.add_argument('--port', type=int, default=8046, help='listen port') - args = parser.parse_args() - - uvicorn.run(app, host=args.host, port=args.port) diff --git a/algorithm/chat/chat_pipelines/naive.py b/algorithm/chat/chat_pipelines/naive.py deleted file mode 100644 index 751678a..0000000 --- a/algorithm/chat/chat_pipelines/naive.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- -# flake8: noqa: E402 -import functools -import os -from pathlib import Path -import lazyllm -from typing import List -from lazyllm import Retriever -from lazyllm import pipeline, parallel, bind, ifs -from lazyllm.tools.rag import TempDocRetriever, Reranker -from lazyllm.tools.rag.rank_fusion.reciprocal_rank_fusion import RRFFusion -import sys - -base_dir = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(base_dir)) - -from common.model import build_embedding_models, build_model, get_runtime_model_settings - -from chat.modules.engineering.simple_llm import SimpleLlmComponent -from chat.modules.algo.multiturn_query_rewriter import MultiturnQueryRewriter -from chat.modules.algo.adaptive_topk import AdaptiveKComponent -from chat.modules.engineering.aggregate import AggregateComponent -from chat.modules.algo.prompt_formatter import RAGContextFormatter -from chat.modules.engineering.output_parser import CustomOutputParser - -USE_MULTIMODAL = False -LLM_TYPE_THINK = False - - -@functools.lru_cache(maxsize=1) -def get_runtime_resources(): - settings = get_runtime_model_settings() - normal_llm = build_model(settings.llm) - normal_llm._prompt._set_model_configs(system='你是一个专业问答助手,你需要根据给定的内容回答用户问题。' - '你将为用户提供安全、有帮助且准确的回答。' - '与此同时,你需要拒绝所有涉及恐怖主义、种族歧视、色情暴力等内容的回答。' - '严禁输出模型名称或来源公司名称。若用户询问或诱导你暴露模型信息,请将自己的身份表述为:“专业问答小助手”。') - - instruct_llm = build_model(settings.llm_instruct) - reranker = build_model(settings.reranker) - embeddings = build_embedding_models(settings) - return { - 'settings': settings, - 'normal_llm': normal_llm, - 'instruct_llm': instruct_llm, - 'reranker': reranker, - 'embeddings': embeddings, - } - -def parse_document_url(url: str) -> tuple[str, str]: - parts = [part.strip() for part in url.split(',', 1)] - if len(parts) == 1 or not parts[1]: - return parts[0], '__default__' - return parts[0], parts[1] - -def get_remote_document(url: str): - base_url, name = parse_document_url(url) - return lazyllm.Document(url=f'{base_url}/_call', name=name) - -def get_remote_docment(url: str): - return get_remote_document(url) - -def setup_retrievers(url: str, retriever_configs: List[dict]) -> List[Retriever]: - document = get_remote_document(url) - return [Retriever(document, **config) for config in retriever_configs] - -def get_ppl_tmp_retriever(): - resources = get_runtime_resources() - settings = resources['settings'] - - def parse_input(input, **kwargs): - files = kwargs.get('files', []) - return files - - ref_docs_retriever = TempDocRetriever(embed=resources['embeddings'][settings.temp_doc_embed_key]) - ref_docs_retriever.add_subretriever('block', topk=20) - with pipeline() as tmp_ppl: - tmp_ppl.parse_input = parse_input - tmp_ppl.tmp_retriever = ref_docs_retriever | bind(query=tmp_ppl.input) - return tmp_ppl - -def get_ppl_search(url: str, retriever_configs: List[dict] | None = None, topk=20, k_max=10): - resources = get_runtime_resources() - settings = resources['settings'] - retriever_configs = retriever_configs or settings.retriever_configs - retrievers = setup_retrievers(url, retriever_configs) - tmp_retriever = get_ppl_tmp_retriever() - with lazyllm.save_pipeline_result(): - with pipeline() as search_ppl: - search_ppl.parse_input = lambda x: x['query'] - search_ppl.divert = ifs((lambda _, x: bool(x.get('files'))) | bind(x=search_ppl.input), - tpath=tmp_retriever | bind(files=search_ppl.input['files']), - fpath=parallel(*[(retriever | bind(filters=search_ppl.input['filters'])) for retriever in retrievers])) - search_ppl.merge_results = lambda *args: args - search_ppl.join = RRFFusion(top_k=50) - search_ppl.reranker = Reranker('ModuleReranker', model=resources['reranker'], topk=topk) | bind( - query=search_ppl.input['query']) - search_ppl.adaptive_k = AdaptiveKComponent(bias=2, k_max=k_max, gap_tau=0.2) - return search_ppl - -def get_ppl_llm_generate(stream=False): - normal_llm = get_runtime_resources()['normal_llm'] - with lazyllm.save_pipeline_result(): - with pipeline() as ppl: - ppl.aggregate = AggregateComponent() - ppl.formatter = RAGContextFormatter() | bind(query=ppl.kwargs['query'], nodes=ppl.aggregate) - # ppl.answer = ifs((lambda _, stream: stream) | bind(stream=stream), - # tpath=StreamCallHelper(normal_llm) | bind(stream=stream, llm_chat_history=[], files=[], priority=1), - # fpath=normal_llm | bind(stream=stream, llm_chat_history=[], files=[], priority=1)) - ppl.answer = SimpleLlmComponent(llm=normal_llm) | \ - bind(stream=stream, llm_chat_history=[], files=[], priority=1) - ppl.parser = CustomOutputParser(llm_type_think=LLM_TYPE_THINK) | bind( - stream=stream, - recall_result=ppl.input, - aggregate=ppl.aggregate, - image_files=[], - debug=ppl.kwargs['debug']) - return ppl - - -def get_rag_ppl(url: str, retriever_configs: List[dict] | None = None, stream=False): - instruct_llm = get_runtime_resources()['instruct_llm'] - with lazyllm.save_pipeline_result(): - with pipeline() as rag_ppl: - rag_ppl.rewriter = ifs( - lambda x: x.get('history'), - tpath=MultiturnQueryRewriter(llm=instruct_llm) - | bind( - priority=rag_ppl.input['priority'], - has_appendix=bool(rag_ppl.input['image_files']) - or bool(rag_ppl.input['files']), - ), - fpath=lambda x: x, - ) - rag_ppl.search = get_ppl_search(url, retriever_configs) # TODO: 根据kb_id判断是否需要检索,依赖jiahao的doc服务 - rag_ppl.generate = get_ppl_llm_generate(stream=stream) | bind( - image_files=[], - query=rag_ppl.input['query'], - history=rag_ppl.input['history'], - debug=rag_ppl.input['debug'],) - return rag_ppl diff --git a/algorithm/chat/components/generate/__init__.py b/algorithm/chat/components/generate/__init__.py new file mode 100644 index 0000000..897bc1d --- /dev/null +++ b/algorithm/chat/components/generate/__init__.py @@ -0,0 +1,9 @@ +from chat.components.generate.aggregate import AggregateComponent +from chat.components.generate.prompt_formatter import RAGContextFormatter +from chat.components.generate.output_parser import CustomOutputParser + +__all__ = [ + 'AggregateComponent', + 'RAGContextFormatter', + 'CustomOutputParser', +] diff --git a/algorithm/chat/modules/engineering/aggregate.py b/algorithm/chat/components/generate/aggregate.py similarity index 100% rename from algorithm/chat/modules/engineering/aggregate.py rename to algorithm/chat/components/generate/aggregate.py diff --git a/algorithm/chat/modules/engineering/output_parser.py b/algorithm/chat/components/generate/output_parser.py similarity index 100% rename from algorithm/chat/modules/engineering/output_parser.py rename to algorithm/chat/components/generate/output_parser.py diff --git a/algorithm/chat/modules/algo/prompt_formatter.py b/algorithm/chat/components/generate/prompt_formatter.py similarity index 100% rename from algorithm/chat/modules/algo/prompt_formatter.py rename to algorithm/chat/components/generate/prompt_formatter.py diff --git a/algorithm/chat/components/process/__init__.py b/algorithm/chat/components/process/__init__.py new file mode 100644 index 0000000..51e5c2f --- /dev/null +++ b/algorithm/chat/components/process/__init__.py @@ -0,0 +1,11 @@ +from chat.components.process.sensitive_filter import SensitiveFilter +from chat.components.process.multiturn_query_rewriter import MultiturnQueryRewriter +from chat.components.process.context_expansion import ContextExpansionComponent +from chat.components.process.adaptive_topk import AdaptiveKComponent + +__all__ = [ + 'SensitiveFilter', + 'MultiturnQueryRewriter', + 'ContextExpansionComponent', + 'AdaptiveKComponent', +] diff --git a/algorithm/chat/modules/algo/adaptive_topk.py b/algorithm/chat/components/process/adaptive_topk.py similarity index 100% rename from algorithm/chat/modules/algo/adaptive_topk.py rename to algorithm/chat/components/process/adaptive_topk.py diff --git a/algorithm/chat/components/process/context_expansion.py b/algorithm/chat/components/process/context_expansion.py new file mode 100644 index 0000000..cad2a7f --- /dev/null +++ b/algorithm/chat/components/process/context_expansion.py @@ -0,0 +1,112 @@ +import time +from typing import List, Optional, Set, Tuple +from lazyllm import LOG, Document +from lazyllm.tools.rag import DocNode + +_RPC_RETRIES = 2 +_RPC_RETRY_DELAY = 0.3 + + +def _get_doc_id(node: DocNode) -> Optional[str]: + return (node.global_metadata or {}).get('docid') + + +def _estimate_tokens(text: str) -> int: + return max(1, len(text) // 4) + + +def _node_sort_key(node: DocNode) -> Tuple: + return (node.metadata.get('index') or 0, node.uid) + + +def _get_node_type(node: DocNode) -> Optional[str]: + try: + md = getattr(node, 'metadata', None) or {} + if isinstance(md, dict): + t = md.get('type') or md.get('node_type') + if isinstance(t, str) and t: + return t + except Exception: + pass + try: + t = getattr(node, 'type', None) + return t if isinstance(t, str) and t else None + except Exception: + return None + + +def _relevance_key(n: DocNode) -> Tuple: + return (-(getattr(n, 'relevance_score', 0.0) or 0.0), n.uid) + + +class ContextExpansionComponent: + def __init__(self, document: Document, token_budget: int = 3000, + score_decay: float = 0.98, max_seeds: Optional[int] = None, + max_new_nodes_per_seed: int = 2): + self.document = document + self.token_budget = token_budget + self.score_decay = score_decay + self.max_seeds = max_seeds + self.max_new_nodes_per_seed = max(1, int(max_new_nodes_per_seed)) + + def _fetch_neighbors(self, node: DocNode, existing_uids: Set[str]) -> List[DocNode]: + doc_id = _get_doc_id(node) + if not doc_id: + return [] + span = (-2, 2) if (_get_node_type(node) or '').lower() == 'table' else (-1, 1) + window = None + for attempt in range(_RPC_RETRIES + 1): + try: + window = self.document.get_window_nodes(node, span=span, merge=False) + break + except Exception as e: + if attempt < _RPC_RETRIES: + time.sleep(_RPC_RETRY_DELAY) + else: + LOG.warning('[CtxExpand] RPC 全部失败 uid=%s: %s', node.uid, e) + return [] + window = window if isinstance(window, list) else ([window] if window else []) + neighbors = [ + n for n in window + if n.uid != node.uid and n.uid not in existing_uids and _get_doc_id(n) == doc_id + ] + neighbors.sort(key=_node_sort_key) + return neighbors + + def __call__(self, nodes: List[DocNode], **kwargs) -> List[DocNode]: + if not nodes: + return nodes + seeds = sorted(nodes, key=_relevance_key) + if self.max_seeds is not None and self.max_seeds > 0: + seeds = seeds[: self.max_seeds] + existing_uids: Set[str] = {n.uid for n in nodes} + all_added: List[DocNode] = [] + added_tokens = 0 + for seed in seeds: + is_table = (_get_node_type(seed) or '').lower() == 'table' + if added_tokens >= self.token_budget and not is_table: + continue + neighbors = self._fetch_neighbors(seed, existing_uids) + seed_score = getattr(seed, 'relevance_score', 0.0) or 0.0 + added_for_seed = 0 + cap = max(self.max_new_nodes_per_seed, 4) if is_table else self.max_new_nodes_per_seed + for nb in neighbors: + if added_for_seed >= cap: + break + nb_tok = _estimate_tokens(nb.text or '') + if not is_table and added_tokens + nb_tok > self.token_budget: + continue + existing_uids.add(nb.uid) + try: + nb.relevance_score = seed_score * self.score_decay + except (AttributeError, TypeError): + pass + all_added.append(nb) + if not is_table: + added_tokens += nb_tok + added_for_seed += 1 + if all_added: + LOG.info('[CtxExpand] 扩展 +%d 节点 (+%d tokens)', len(all_added), added_tokens) + result = list(nodes) + all_added + result.sort(key=_relevance_key) + return result diff --git a/algorithm/chat/modules/algo/multiturn_query_rewriter.py b/algorithm/chat/components/process/multiturn_query_rewriter.py similarity index 61% rename from algorithm/chat/modules/algo/multiturn_query_rewriter.py rename to algorithm/chat/components/process/multiturn_query_rewriter.py index b2ae863..2ccc45b 100644 --- a/algorithm/chat/modules/algo/multiturn_query_rewriter.py +++ b/algorithm/chat/components/process/multiturn_query_rewriter.py @@ -6,44 +6,8 @@ from lazyllm.components import ChatPrompter from lazyllm.components.formatter import JsonFormatter -from chat.utils.message import BaseMessage, SessionMemory - - -MULTITURN_QUERY_REWRITE_PROMPT = """ -你是“多轮对话 Query 改写器”。在检索前,将用户最后一问改写成 -【语义完整、上下文自洽、可独立理解】的一句话查询。只做改写,不回答。 - -必须遵守: -1) 遵循**保守改写**原则 - - 仅在必要时改写:指代不明、关键约束仅存在于上下文、多轮任务延续等 - - 若 last_user_query 脱离任何上下文仍语义完整,不得进行任何程度的加工和改写(名词替换、句式调整等)。 -2) 结合 chat_history 与 session_memory 解析指代与省略;继承已给出的时间/地点/来源/语言等约束。 - - 输入中提供变量 has_appendix 表示用户是否上传了附件。若 last_user_query 中存在指示代词 - (如“这是谁 / 这两个人 / 这里 / 那张表”),必须先判断指代来源是历史对话还是上传的附件;确保不把附件指代误改写为历史内容,或反之。 - - 若指代来源无法确定,则保持保守改写或不改写,不做臆测。 -3) 将“今天/近两年/上周”等相对时间,基于 current_date 归一为绝对日期或区间。 -4) 不臆造事实或新增约束;若存在歧义,做**保守改写**并下调 confidence,在 rationale_short 说明原因。 -5) 若上轮限定了信息源/文档集合,需在 rewritten_query 和 constraints.filters.source 中显式保留。 -6) 语言跟随 last_user_query;若提供 user_locale 且一致,则优先使用该语言。 -7) 仅输出一个 JSON 对象;不要包含除规定字段外的任何内容。 - -输出 JSON(严格按此结构): -{ - "rewritten_query": "<面向检索的一句话,完整可独立理解>", - "language": "zh", - "constraints": { - "must_include": [], - "filters": { - "time": { "from": null, "to": null, "points": [] }, - "source": [], - "entity": [] - }, - "exclude_terms": [] - }, - "confidence": 0.0, - "rationale_short": "<1-2句说明改写要点/歧义与处理>" -} -""" +from chat.utils.schema import BaseMessage, SessionMemory +from chat.prompts.rewrite import MULTITURN_QUERY_REWRITE_PROMPT class RewriterInput(BaseModel): diff --git a/algorithm/chat/modules/engineering/sensitive_filter.py b/algorithm/chat/components/process/sensitive_filter.py similarity index 100% rename from algorithm/chat/modules/engineering/sensitive_filter.py rename to algorithm/chat/components/process/sensitive_filter.py diff --git a/algorithm/chat/components/tmp/__init__.py b/algorithm/chat/components/tmp/__init__.py new file mode 100644 index 0000000..c01db1f --- /dev/null +++ b/algorithm/chat/components/tmp/__init__.py @@ -0,0 +1,6 @@ +from chat.components.tmp.local_models import BgeM3Embed, Qwen3Rerank + +__all__ = [ + 'BgeM3Embed', + 'Qwen3Rerank', +] diff --git a/algorithm/common/model/reranker.py b/algorithm/chat/components/tmp/local_models.py similarity index 64% rename from algorithm/common/model/reranker.py rename to algorithm/chat/components/tmp/local_models.py index 3596192..6383f75 100644 --- a/algorithm/common/model/reranker.py +++ b/algorithm/chat/components/tmp/local_models.py @@ -4,7 +4,54 @@ from lazyllm import LOG from lazyllm.tools.rag.doc_node import DocNode, MetadataMode -from lazyllm.module.llms.onlinemodule.base import LazyLLMOnlineRerankModuleBase +from lazyllm.module.llms.onlinemodule.base import LazyLLMOnlineEmbedModuleBase, LazyLLMOnlineRerankModuleBase + + +class BgeM3Embed(LazyLLMOnlineEmbedModuleBase): + NO_PROXY = True + + def __init__(self, embed_url: str = '', embed_model_name: str = 'custom', api_key: str = None, + skip_auth: bool = True, batch_size: int = 16, **kw): + super().__init__(embed_url=embed_url, api_key='' if skip_auth else (api_key or ''), + embed_model_name=embed_model_name, + skip_auth=skip_auth, batch_size=batch_size, **kw) + + def _set_embed_url(self): + pass + + def _encapsulated_data(self, input: Union[List, str], **kwargs): + model = kwargs.get('model', self._embed_model_name) + extras = {k: v for k, v in kwargs.items() if k not in ('model',)} + if isinstance(input, str): + json_data: Dict = {'inputs': input} + if model: + json_data['model'] = model + json_data.update(extras) + return json_data + text_batch = [input[i: i + self._batch_size] for i in range(0, len(input), self._batch_size)] + out = [] + for texts in text_batch: + item: Dict = {'inputs': texts} + if model: + item['model'] = model + item.update(extras) + out.append(item) + return out + + def _parse_response(self, response: Union[Dict, List], input: Union[List, str] + ) -> Union[List[float], List[List[float]], Dict]: + if isinstance(response, dict): + if 'data' in response: + return super()._parse_response(response, input) + return response + if isinstance(response, list): + if not response: + raise RuntimeError('empty embedding response') + if isinstance(input, str): + first = response[0] + return response if isinstance(first, float) else first + return response + raise RuntimeError(f'unexpected embedding response type: {type(response)!r}') class Qwen3Rerank(LazyLLMOnlineRerankModuleBase): @@ -27,7 +74,6 @@ def __init__( embed_model_name: str = 'Qwen3Reranker', embed_url: Optional[str] = None, api_key: str = 'api_key', - skip_auth: bool = False, batch_size: int = 64, truncate_text: bool = True, output_format: Optional[str] = None, @@ -41,12 +87,7 @@ def __init__( Args: task_description: 任务描述,会被拼入 system/user 区块。 """ - super().__init__( - embed_url=embed_url, - api_key=api_key, - embed_model_name=embed_model_name, - skip_auth=skip_auth, - ) + super().__init__(embed_url=embed_url, api_key=api_key, embed_model_name=embed_model_name) if not embed_url: raise ValueError('`url` 不能为空,请传入远端重排序服务地址。') @@ -125,70 +166,26 @@ def _truncate_if_needed(s: str) -> str: docs.append(self._DOCUMENT_TEMPLATE.format(doc=t_norm, suffix=self._PROMPT_SUFFIX)) return docs - def _encapsulated_data(self, query: str, texts: List[str], **kwargs: Any) -> Dict[str, Any]: + def _encapsulated_data(self, query: str, **kwargs: Any) -> Dict[str, Any]: + documents = kwargs.pop('documents', []) payload: Dict[str, Any] = { 'query': self._build_instruct(self._task_description, query), - 'documents': self._build_documents(texts), + 'documents': self._build_documents(documents), } - if kwargs: - for k, v in kwargs.items(): - if k not in ('query', 'documents'): - payload[k] = v + for k, v in kwargs.items(): + if k not in ('query', 'documents', 'top_n', 'top_k', 'topk'): + payload[k] = v return payload - def _parse_response(self, response: Any) -> List[float]: - """ - 期望输入: - {"results": [{"index": int, "relevance_score": float}, ...]} - """ + def _parse_response(self, response: Any, input=None) -> List[tuple]: + """返回 [(index, relevance_score), ...], 与 SiliconFlowRerank / ModuleReranker 协议一致。""" if not isinstance(response, dict) or 'results' not in response: LOG.warning("response missing 'results' field: %r", response) return [] results = response.get('results', []) try: - results = sorted(results, key=lambda x: x['index']) - return [float(item['relevance_score']) for item in results] + return [(item['index'], float(item['relevance_score'])) for item in results] except Exception as exc: LOG.error('Failed to parse response: %s; response=%r', exc, response) return [] - - def forward(self, nodes: List[DocNode], query: str, **kwargs: Any) -> List[DocNode]: - if not nodes: - return [] - - texts = self._get_format_content(nodes, **kwargs) - top_k = self._extract_top_k(len(texts), **kwargs) - - all_scores: List[float] = [] - for start in range(0, len(texts), self._batch_size): - batch_texts = texts[start:start + self._batch_size] - payload = self._encapsulated_data(query, batch_texts, **kwargs) - - try: - resp = self._session.post( - self._url, json=payload, headers=self._headers, timeout=self._timeout - ) - resp.raise_for_status() - scores = self._parse_response(resp.json()) - except requests.RequestException as exc: - LOG.error('HTTP request for reranking failed (this batch will be scored as 0): %s', exc) - scores = [] - - if len(scores) != len(batch_texts): - LOG.warning( - 'Returned scores count mismatches inputs: got=%d, expected=%d; padding with zeros.', - len(scores), len(batch_texts), - ) - if len(scores) < len(batch_texts): - scores += [0.0] * (len(batch_texts) - len(scores)) - else: - scores = scores[:len(batch_texts)] - - all_scores.extend(scores) - - scored_nodes: List[DocNode] = [nodes[i].with_score(all_scores[i]) for i in range(len(nodes))] - scored_nodes.sort(key=lambda n: n.relevance_score, reverse=True) - results = scored_nodes[:top_k] if top_k > 0 else scored_nodes - LOG.debug(f'Rerank use `{self._embed_model_name}` and get nodes: {results}') - return results diff --git a/algorithm/chat/components/tmp/tool_registry.py b/algorithm/chat/components/tmp/tool_registry.py new file mode 100644 index 0000000..67b7129 --- /dev/null +++ b/algorithm/chat/components/tmp/tool_registry.py @@ -0,0 +1,68 @@ +from typing import Dict, Any +from abc import ABC, abstractmethod +import os + +DOCUMENT_URL = os.getenv('LAZYLLM_DOCUMENT_URL', 'http://127.0.0.1:8525') + + +class BaseTool(ABC): + @property + @abstractmethod + def tool_schema(self) -> Dict[str, Any]: + """ + 工具 schema,描述工具的功能和参数 + + Returns: + Dict[str, Any]: 工具 schema 字典,格式为: + { + 'tool_name': { + 'description': '工具描述', + 'parameters': { + 'param_name': { + 'type': '参数类型', + 'des': '参数描述' + } + } + } + } + """ + pass + + @abstractmethod + def __call__(self, *args, **kwargs) -> Any: + pass + + @property + def tool_name(self) -> str: + """返回工具名称,默认使用 schema 中的第一个 key""" + if self.tool_schema: + return list(self.tool_schema.keys())[0] + return self.__class__.__name__.lower() + + +# 工具注册表:自动收集所有 BaseTool 子类的实例 +_tool_instances: Dict[str, BaseTool] = {} +_tool_schemas: Dict[str, Dict[str, Any]] = {} + + +def register_tool(tool_name: str, tool_instance: BaseTool): + if not isinstance(tool_instance, BaseTool): + raise TypeError(f'Tool instance must be a subclass of BaseTool, got {type(tool_instance)}') + _tool_instances[tool_name] = tool_instance + _tool_schemas[tool_name] = tool_instance.tool_schema + + +def get_all_tool_schemas() -> Dict[str, Dict[str, Any]]: + return _tool_schemas.copy() + + +def get_tool_schema(tool_name: str) -> Dict[str, Any]: + if tool_name not in _tool_schemas: + raise KeyError(f'Tool {tool_name!r} not found in registry') + return _tool_schemas[tool_name] + + +def get_tool_instance(tool_name: str) -> BaseTool: + if tool_name not in _tool_instances: + raise KeyError(f'Tool {tool_name!r} not found in registry') + return _tool_instances[tool_name] diff --git a/algorithm/chat/config.py b/algorithm/chat/config.py new file mode 100644 index 0000000..78394fc --- /dev/null +++ b/algorithm/chat/config.py @@ -0,0 +1,44 @@ +import os +from typing import Dict + +from dotenv import load_dotenv + +load_dotenv() + +MOUNT_BASE_DIR: str = os.getenv('LAZYLLM_MOUNT_DIR', '/data') +SENSITIVE_WORDS_PATH: str = os.getenv('SENSITIVE_WORDS_PATH', 'data/sensitive_words.txt') + +_LAZYRAG_LLM_PRIORITY_ENV = os.getenv('LAZYRAG_LLM_PRIORITY') +LAZYRAG_LLM_PRIORITY = ( + int(_LAZYRAG_LLM_PRIORITY_ENV) + if _LAZYRAG_LLM_PRIORITY_ENV is not None and _LAZYRAG_LLM_PRIORITY_ENV.isdigit() + else 0 +) +USE_MULTIMODAL = False +LLM_TYPE_THINK = False + +MAX_CONCURRENCY = int(os.getenv('MAX_CONCURRENCY', 10)) +RAG_MODE = os.getenv('RAG_MODE', 'True').lower() == 'true' +MULTIMODAL_MODE = os.getenv('MULTIMODAL_MODE', 'True').lower() == 'true' + +SENSITIVE_FILTER_RESPONSE_TEXT = '对不起,我还没有学会回答这个问题。如果你有其他问题,我非常乐意为你提供帮助。' + +IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg') +DEFAULT_TMP_BLOCK_TOPK = 20 + +DEFAULT_ALGO_SERVICE_URL = os.getenv('LAZYRAG_ALGO_SERVICE_URL', 'http://lazyllm-algo:8000').rstrip('/') +DEFAULT_ALGO_DATASET_NAME = os.getenv('LAZYRAG_ALGO_DATASET_NAME', 'general_algo') +DEFAULT_CHAT_DATASET = os.getenv('LAZYRAG_DEFAULT_CHAT_DATASET', 'algo') + +URL_MAP: Dict[str, str] = { + 'algo': f'{DEFAULT_ALGO_SERVICE_URL},{DEFAULT_ALGO_DATASET_NAME}', + 'default': f'{DEFAULT_ALGO_SERVICE_URL},{DEFAULT_ALGO_DATASET_NAME}', + 'general_algo': f'{DEFAULT_ALGO_SERVICE_URL},{DEFAULT_ALGO_DATASET_NAME}', + 'research_center': 'http://10.119.16.66:9003,research_center_0131_a', + 'quantum': 'http://10.119.16.66:9002,quantum_0131_a', + 'tyy': 'http://10.119.16.66:9007,tyy_0302', + 'cf': 'http://10.119.16.66:9005,cf_0304', + '3m': 'http://10.119.16.66:9006,threem_0303', + 'crag': 'http://10.119.16.66:9001,crag_0130_a', + 'debug': 'http://127.0.0.1:8525', +} diff --git a/algorithm/chat/modules/engineering/simple_llm.py b/algorithm/chat/modules/engineering/simple_llm.py deleted file mode 100644 index dfad92b..0000000 --- a/algorithm/chat/modules/engineering/simple_llm.py +++ /dev/null @@ -1,128 +0,0 @@ -import asyncio -from typing import Optional, Any, List -from pydantic import BaseModel, Field - -import lazyllm -from lazyllm import ModuleBase -from lazyllm.module import OnlineChatModuleBase -from lazyllm.components.prompter import PrompterBase -from lazyllm.components.formatter import FormatterBase - - -class LlmStrategy(BaseModel): - temperature: Optional[float] = Field( - 0.01, ge=0.0, le=1.0, description='采样温度,取值范围为[0.0, 1.0],默认值为0.01' - ) - max_tokens: Optional[int] = Field( - 4096, ge=1, le=16048, description='最大token数,取值范围为[1, 16048],默认值为4096' - ) - frequency_penalty: Optional[float] = Field( - 0, - ge=-2.0, - le=2.0, - description='重复惩罚,取值范围为[-2.0, 2.0],默认值为0。正值为减少产生相同token的频率。', - ) - priority: Optional[int] = Field( - 0, - description=( - '请求优先级,用于vllm调度。数值越大优先级越高,默认值为None(使用系统默认优先级)' - ), - ) - - class Config: - extra = 'allow' - - -class SimpleLlmComponent(ModuleBase): - def __init__( - self, - llm: OnlineChatModuleBase, - prompter=None, - return_trace: bool = False, - **kwargs, - ): - super().__init__(return_trace=return_trace) - self.llm = llm - - @property - def series(self): - return 'LlmComponent' - - @property - def type(self): - return 'LLM' - - def share( - self, - prompt: PrompterBase = None, - format: FormatterBase = None, - stream: Optional[bool] = None, - history: List[List[str]] = None, - copy_static_params: bool = False, - ): - self.llm = self.llm.share( - prompt=prompt, - format=format, - stream=stream, - history=history, - copy_static_params=copy_static_params, - ) - return self - - async def astream_iterator(self, input, llm, files, llm_chat_history=None, **kwargs): - if llm_chat_history is None: - llm_chat_history = [] - with lazyllm.ThreadPoolExecutor(1) as executor: - future = executor.submit( - llm, - input, - llm_chat_history=llm_chat_history, - lazyllm_files=files, - stream_output=True, - **kwargs, - ) - while True: - if value := lazyllm.FileSystemQueue().dequeue(): - yield ''.join(value) - elif future.done(): - break - else: - await asyncio.sleep(0.1) - llm = None - - def forward(self, query, files=None, stream=True, **kwargs: Any) -> Any: - try: - lazyllm.LOG.info(f'MODEL_NAME: {self.llm._model_name} GOT QUERY: {query}') - files = files[:2] if files else None - llm_chat_history = kwargs.pop('llm_chat_history', []) - - priority = kwargs.pop('priority', 0) - llm_strategy = kwargs.get('llm_strategy', LlmStrategy(priority=priority)) - if isinstance(llm_strategy, LlmStrategy): - llm_strategy = llm_strategy.model_dump() - - llm = self.llm.share() - kw = {k: v for k, v in llm_strategy.items() if v is not None} - - if stream: - response = self.astream_iterator( - input=query, - llm=llm, - files=files, - llm_chat_history=llm_chat_history, - **kw, - ) - else: - response = llm( - query, - stream_output=False, - llm_chat_history=llm_chat_history, - lazyllm_files=files, - **kw, - ) - return response - except Exception as e: - lazyllm.LOG.exception(e) - raise e - finally: - llm = None diff --git a/algorithm/chat/modules/engineering/tool_registry.py b/algorithm/chat/modules/engineering/tool_registry.py deleted file mode 100644 index 7211287..0000000 --- a/algorithm/chat/modules/engineering/tool_registry.py +++ /dev/null @@ -1,238 +0,0 @@ -from typing import List, Dict, Any -from abc import ABC, abstractmethod -from dataclasses import dataclass, field -import copy -import os -import itertools -import lazyllm -from lazyllm.tools.rag import Retriever -from chat.chat_pipelines.naive import get_ppl_search, parse_document_url -from common.model import get_runtime_model_settings - -DOCUMENT_URL = os.getenv('LAZYLLM_DOCUMENT_URL', 'http://127.0.0.1:8525') - - -class BaseTool(ABC): - """工具基类,要求所有工具类都必须定义 tool_schema 和 __call__""" - - @property - @abstractmethod - def tool_schema(self) -> Dict[str, Any]: - """ - 工具 schema,描述工具的功能和参数 - - Returns: - Dict[str, Any]: 工具 schema 字典,格式为: - { - 'tool_name': { - 'description': '工具描述', - 'parameters': { - 'param_name': { - 'type': '参数类型', - 'des': '参数描述' - } - } - } - } - """ - pass - - @abstractmethod - def __call__(self, *args, **kwargs) -> Any: - """ - 执行工具调用,子类必须实现。 - - Returns: - Any: 工具执行结果 - """ - pass - - @property - def tool_name(self) -> str: - """返回工具名称,默认使用 schema 中的第一个 key""" - if self.tool_schema: - return list(self.tool_schema.keys())[0] - return self.__class__.__name__.lower() - - -# input querys original_query document_url params return: list of content -@dataclass -class KBSearchMemory: - """KBSearch 工具的内存结构""" - nodes: list = field(default_factory=list) - relevant_nodes: list = field(default_factory=list) - agg_nodes: dict = field(default_factory=dict) - - -class KBSearch(BaseTool): - """知识库检索工具类""" - - def __call__(self, querys, static_params, file_names=None): - return self.chunk_search(querys=querys, file_names=file_names, static_params=static_params) - - @property - def tool_schema(self) -> Dict[str, Any]: - """工具 schema 定义""" - return { - 'kb_search': { - 'description': 'Used to retrieve content chunks from the knowledge \ - base and extract target information points. Supports global and document-scoped search.', - 'parameters': { - 'querys': { - 'type': 'List[str]', - 'des': 'Distinct semantic queries; avoid overlap.' - }, - 'file_names': { - 'type': 'Optional[List[str]] = None', - 'des': 'Restrict search to specific documents; None means global.' - } - } - } - } - - def chunk_search( - self, - querys: List[str], - file_names: List[str] = None, - static_params: dict = None, - ) -> List[str]: - """执行文档块检索""" - if static_params is None: - static_params = {} - - original_query = static_params.get('query', '') - document_url = static_params.get('document_url', DOCUMENT_URL) - - if file_names: - file_ids = self.file_search(file_names, static_params) - else: - file_ids = None - - search_ppl = get_ppl_search(url=document_url) - - node_ids = set() - params = copy.deepcopy(static_params) - - file_names_unique = set() - unique_nodes = [] - - for query in querys: - if file_ids: - params['filters'].update({'docid': file_ids}) - nodes = search_ppl(params | {'query': query}) - for node in nodes: - if node._uid not in node_ids: - file_names_unique.add(node.global_metadata.get('file_name', '')) - node_ids.add(node._uid) - unique_nodes.append(node) - if file_ids: - original_nodes = search_ppl(params | {'query': original_query}) - for node in original_nodes: - if node._uid not in node_ids: - file_names_unique.add(node.global_metadata.get('file_name', '')) - node_ids.add(node._uid) - unique_nodes.append(node) - - nodes = [] - for _, grp in itertools.groupby(unique_nodes, key=lambda x: x.global_metadata['docid']): - grouped_nodes = list(grp) - new_node = grouped_nodes[0] - grouped_nodes = sorted(grouped_nodes, key=lambda x: x.metadata.get('index', 0)) - contents = [ - '{}\n{}'.format(node.metadata.get('title', '').strip(), node.get_text()) - for node in grouped_nodes - ] - new_node._content = '\n'.join(contents) - nodes.append(new_node) - - res = [] - for node in nodes: - filename = node.metadata.get('file_name', '') - res.append(f'file_name: {filename}\n{node.get_text()}') - - return nodes, res - - def file_search(self, file_names: List[str], static_params: dict = None, topk: int = 3) -> List[str]: - filters = static_params.get('filters', {}) if static_params else {} - document_url = static_params.get('document_url', DOCUMENT_URL) if static_params else DOCUMENT_URL - base_url, default_name = parse_document_url(document_url) - name = static_params.get('name', default_name) if static_params else default_name - settings = get_runtime_model_settings() - doc = lazyllm.Document(url=f'{base_url}/_call', name=name) - retriever = Retriever( - doc, - group_name='filename', - embed_keys=[settings.file_search_embed_key], - topk=topk, - ) - - file_ids = [] - for file_name in file_names: - nodes = retriever(file_name, filters=filters) - for node in nodes: - docid = node.global_metadata.get('docid', '') - if docid and docid not in file_ids: - file_ids.append(docid) - return file_ids - - -# 工具注册表:自动收集所有 BaseTool 子类的实例 -_tool_instances: Dict[str, BaseTool] = {} -_tool_schemas: Dict[str, Dict[str, Any]] = {} - - -def register_tool(tool_name: str, tool_instance: BaseTool): - """ - 注册工具实例 - - Args: - tool_name: 工具名称 - tool_instance: 工具实例(必须是 BaseTool 的子类) - """ - if not isinstance(tool_instance, BaseTool): - raise TypeError(f'Tool instance must be a subclass of BaseTool, got {type(tool_instance)}') - _tool_instances[tool_name] = tool_instance - _tool_schemas[tool_name] = tool_instance.tool_schema - - -def get_all_tool_schemas() -> Dict[str, Dict[str, Any]]: - """ - 获取所有已注册工具的 schema - - Returns: - Dict[str, Dict[str, Any]]: 所有工具的 schema 字典 - """ - return _tool_schemas.copy() - - -def get_tool_schema(tool_name: str) -> Dict[str, Any]: - """ - 获取指定工具的 schema - - Args: - tool_name: 工具名称 - - Returns: - Dict[str, Any]: 工具的 schema 字典 - """ - if tool_name not in _tool_schemas: - raise KeyError(f'Tool {tool_name!r} not found in registry') - return _tool_schemas[tool_name] - - -def get_tool_instance(tool_name: str) -> BaseTool: - """ - 获取指定工具的实例 - - Args: - tool_name: 工具名称 - - Returns: - BaseTool: 工具实例 - """ - if tool_name not in _tool_instances: - raise KeyError(f'Tool {tool_name!r} not found in registry') - return _tool_instances[tool_name] - - -register_tool('kb_search', KBSearch()) diff --git a/algorithm/chat/modules/engineering/workflow_utils.py b/algorithm/chat/modules/engineering/workflow_utils.py deleted file mode 100644 index 50db8e2..0000000 --- a/algorithm/chat/modules/engineering/workflow_utils.py +++ /dev/null @@ -1,90 +0,0 @@ -from dataclasses import dataclass, field -from typing import Any, List - - -@dataclass -class MiddleResults: - evaluation_result: dict = field(default_factory=dict) - raw_results: list = field(default_factory=list) - formatted_results: list = field(default_factory=list) - agg_results: dict = field(default_factory=dict) - - -@dataclass -class ToolMemory: - nodes: list = field(default_factory=list) - - -@dataclass -class ToolCall: - name: str - input: dict - output: Any - - -@dataclass -class PlanStep: - step_id: int - goal: str # 这一轮想搞清楚什么 - tool: str - status: str = 'pending' - raw_results: list = field(default_factory=list) # 原始结果类型,包含metadata,方便回溯内容来源 - formatted_results: list = field(default_factory=list) # 结构化的结果,str,用于中间推理和答案生成 - extracted_results: list = field(default_factory=list) # 过滤过的结构化结果, 对应更新raw_results - inference: str = '' - - -@dataclass -class ReasoningProcess: - tools: list[ToolCall] = field(default_factory=list) - - -@dataclass -class TaskContext: - query: str = '' - global_params: dict = field(default_factory=dict) - tool_params: dict = field(default_factory=dict) - pending_steps: List[PlanStep] = field(default_factory=list) - executed_steps: List[PlanStep] = field(default_factory=list) - middle_results: MiddleResults = field(default_factory=MiddleResults) - inferences: List[str] = field(default_factory=list) - reasoning_process_stream: List[str] = field(default_factory=list) # 推理过程,用于流式输出 - answer: str = '' - - -def tool_schema_to_string( - tool_schema: dict, - include_params: bool = True -) -> str: - lines = [] - - for tool_name, tool_info in tool_schema.items(): - lines.append(f'TOOL NAME: {tool_name}') - - # description - desc = tool_info.get('description') - if desc: - lines.append('DESCRIPTION:') - for sent in desc.split('. '): - sent = sent.strip() - if sent: - lines.append(f"- {sent.rstrip('.')}.") - - # parameters - if include_params: - params = tool_info.get('parameters', {}) - if params: - lines.append('PARAMETERS:') - for param_name, param_info in params.items(): - p_type = param_info.get('type', 'Any') - p_desc = param_info.get('des', '') - if p_desc: - lines.append( - f'- {param_name}: {p_type} — {p_desc}' - ) - else: - lines.append( - f'- {param_name}: {p_type}' - ) - - return '\n'.join(lines).strip() diff --git a/algorithm/chat/pipelines/__init__.py b/algorithm/chat/pipelines/__init__.py new file mode 100644 index 0000000..560da73 --- /dev/null +++ b/algorithm/chat/pipelines/__init__.py @@ -0,0 +1,11 @@ +# 核心流水线定义 +# 包含了agentic和naive两种模式,分别对应agentic.py和naive.py + +from chat.pipelines.agentic import get_ppl_agentic, agentic_rag +from chat.pipelines.naive import get_ppl_naive + +__all__ = [ + 'get_ppl_agentic', + 'get_ppl_naive', + 'agentic_rag', +] diff --git a/algorithm/chat/chat_pipelines/agentic.py b/algorithm/chat/pipelines/agentic.py similarity index 96% rename from algorithm/chat/chat_pipelines/agentic.py rename to algorithm/chat/pipelines/agentic.py index 3e4fcf8..4bd3d31 100644 --- a/algorithm/chat/chat_pipelines/agentic.py +++ b/algorithm/chat/pipelines/agentic.py @@ -5,7 +5,6 @@ import itertools import json import re -import os from concurrent.futures import ThreadPoolExecutor from lazyllm import LOG, bind, loop, pipeline, switch from tenacity import retry, stop_after_attempt, wait_fixed @@ -15,9 +14,7 @@ base_dir = Path(__file__).resolve().parents[2] sys.path.insert(0, str(base_dir)) -from common.model import build_model, get_runtime_model_settings - -from chat.modules.engineering.simple_llm import SimpleLlmComponent +from chat.pipelines.builders.get_models import get_automodel from chat.prompts.agentic import ( EVALUATOR_PROMPT, EXTRACTOR_PROMPT, @@ -26,17 +23,14 @@ PLANNER_PROMPT, TOOLCALL_PROMPT, ) -from chat.modules.engineering.tool_registry import ( +from chat.components.tmp.tool_registry import ( get_all_tool_schemas, get_tool_instance, get_tool_schema, ) -from chat.modules.engineering.output_parser import CustomOutputParser -from chat.modules.engineering.workflow_utils import ( - PlanStep, - TaskContext, - tool_schema_to_string, -) +from chat.components.generate.output_parser import CustomOutputParser +from chat.utils.schema import PlanStep, TaskContext +from chat.utils.helpers import tool_schema_to_string # global params and func @@ -53,16 +47,15 @@ def add_reasoning_process_stream(state: TaskContext, value: str, mode: str = 'in @functools.lru_cache(maxsize=1) def get_agentic_llms(): - settings = get_runtime_model_settings() - llm = build_model(settings.llm) + llm = get_automodel('llm') llm._prompt._set_model_configs(system='You are an intelligent assistant, \ strictly following user instructions to execute tasks.') - - llm_instruct = build_model(settings.llm_instruct) + llm_instruct = get_automodel('llm_instruct') llm_instruct._prompt._set_model_configs(system='You are a task assistant, \ and you must strictly follow the given requirements to complete the tasks.\ The output language should be the same as the input language.') - return llm, llm_instruct, SimpleLlmComponent(llm=llm_instruct) + llm_gen = get_automodel('llm_instruct', wrap_simple_llm=True) + return llm, llm_instruct, llm_gen # utils @@ -418,7 +411,9 @@ async def astream_iterator(agent, state): await asyncio.sleep(0.1) -agent = get_ppl_agentic() +@functools.lru_cache(maxsize=1) +def _get_agent(): + return get_ppl_agentic() def agentic_rag(global_params, tool_params, stream=False, **kwargs): @@ -434,6 +429,7 @@ def agentic_rag(global_params, tool_params, stream=False, **kwargs): state.global_params = global_params state.tool_params = tool_params state.middle_results.agg_results = {} + agent = _get_agent() if stream: as_iter = astream_iterator(agent, state) agg_nodes = state.middle_results.agg_results diff --git a/algorithm/chat/pipelines/builders/__init__.py b/algorithm/chat/pipelines/builders/__init__.py new file mode 100644 index 0000000..9b86995 --- /dev/null +++ b/algorithm/chat/pipelines/builders/__init__.py @@ -0,0 +1,12 @@ +from chat.pipelines.builders.get_models import get_automodel +from chat.pipelines.builders.get_retriever import get_retriever, get_remote_docment +from chat.pipelines.builders.get_ppl_search import get_ppl_search +from chat.pipelines.builders.get_ppl_generate import get_ppl_generate + +__all__ = [ + 'get_automodel', + 'get_retriever', + 'get_remote_docment', + 'get_ppl_search', + 'get_ppl_generate', +] diff --git a/algorithm/chat/pipelines/builders/get_models.py b/algorithm/chat/pipelines/builders/get_models.py new file mode 100644 index 0000000..8c4045a --- /dev/null +++ b/algorithm/chat/pipelines/builders/get_models.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import asyncio +import atexit +import functools +import hashlib +import os +import re +import shutil +import tempfile +import threading +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict, List, Optional + +import yaml +import lazyllm +from lazyllm import AutoModel, ModuleBase +from lazyllm.components.formatter import FormatterBase +from lazyllm.components.prompter import PrompterBase + +from chat.utils.load_config import get_role_config + +_RUNTIME_AUTO_MODEL_DIR = Path(tempfile.gettempdir()) / 'lazyrag-runtime-auto-model' + +_DEFAULT_LLM_KW: Dict[str, Any] = { + 'temperature': 0.01, + 'max_tokens': 4096, + 'frequency_penalty': 0, +} + +_lock = threading.RLock() +_base_models: Dict[str, Any] = {} +_wrapped_models: Dict[str, Any] = {} + + +def _cleanup_runtime_auto_model_dir() -> None: + shutil.rmtree(_RUNTIME_AUTO_MODEL_DIR, ignore_errors=True) + + +atexit.register(_cleanup_runtime_auto_model_dir) + + +class _StreamingLlmModule(ModuleBase): + def __init__(self, llm: Any, return_trace: bool = False): + super().__init__(return_trace=return_trace) + self.llm = llm + + @property + def series(self): + return 'LlmComponent' + + @property + def type(self): + return 'LLM' + + def share(self, prompt: PrompterBase = None, format: FormatterBase = None, + stream: Optional[bool] = None, history: List[List[str]] = None, + copy_static_params: bool = False): + self.llm = self.llm.share( + prompt=prompt, format=format, stream=stream, + history=history, copy_static_params=copy_static_params, + ) + return self + + async def _astream(self, text, llm, files, history, **kw): + with lazyllm.ThreadPoolExecutor(1) as executor: + future = executor.submit( + llm, text, + llm_chat_history=history, + lazyllm_files=files, + stream_output=True, + **kw, + ) + while True: + if value := lazyllm.FileSystemQueue().dequeue(): + yield ''.join(value) + elif future.done(): + break + else: + await asyncio.sleep(0.1) + + def forward(self, query, files=None, stream=True, **kwargs: Any) -> Any: + llm = None + try: + lazyllm.LOG.info(f'MODEL_NAME: {self.llm._model_name} GOT QUERY: {query}') + files = files[:2] if files else None + hist = kwargs.pop('llm_chat_history', []) + priority = kwargs.pop('priority', 0) + strat = kwargs.get('llm_strategy') + raw = {**_DEFAULT_LLM_KW, 'priority': priority} if strat is None else dict(strat) + kw = {k: v for k, v in raw.items() if v is not None} + llm = self.llm.share() + if stream: + return self._astream(query, llm, files, hist, **kw) + return llm(query, stream_output=False, llm_chat_history=hist, + lazyllm_files=files, **kw) + except Exception as e: + lazyllm.LOG.exception(e) + raise + finally: + llm = None + + +@functools.lru_cache(maxsize=64) +def _write_auto_model_config(serialized_config: str) -> str: + config = yaml.safe_load(serialized_config) + model_name = config['model'] + digest = hashlib.sha256(serialized_config.encode('utf-8')).hexdigest()[:16] + safe_name = re.sub(r'[^A-Za-z0-9_.-]+', '-', model_name).strip('-') or 'model' + _RUNTIME_AUTO_MODEL_DIR.mkdir(parents=True, exist_ok=True) + config_path = _RUNTIME_AUTO_MODEL_DIR / f'{safe_name}-{digest}.yaml' + temp_fd, temp_path = tempfile.mkstemp( + dir=_RUNTIME_AUTO_MODEL_DIR, prefix=f'.{safe_name}-{digest}-', suffix='.yaml', + ) + try: + with os.fdopen(temp_fd, 'w', encoding='utf-8') as f: + yaml.safe_dump({model_name: [config]}, f, sort_keys=False) + os.replace(temp_path, config_path) + except Exception: + try: + os.unlink(temp_path) + except FileNotFoundError: + pass + raise + return str(config_path) + + +def _build_auto_model(model_name: str, config: Dict[str, Any]): + cfg = deepcopy(config) + cfg['model'] = model_name + serialized = yaml.safe_dump(cfg, sort_keys=True) + return AutoModel(model=model_name, config=_write_auto_model_config(serialized)) + + +def get_automodel(role: str, *, wrap_simple_llm: bool = False) -> Any: + with _lock: + if role not in _base_models: + model_name, config = get_role_config(role) + _base_models[role] = _build_auto_model(model_name, config) + base = _base_models[role] + if not wrap_simple_llm: + return base + if role not in _wrapped_models: + _wrapped_models[role] = _StreamingLlmModule(llm=base) + return _wrapped_models[role] diff --git a/algorithm/chat/pipelines/builders/get_ppl_generate.py b/algorithm/chat/pipelines/builders/get_ppl_generate.py new file mode 100644 index 0000000..eceea32 --- /dev/null +++ b/algorithm/chat/pipelines/builders/get_ppl_generate.py @@ -0,0 +1,31 @@ +import lazyllm +from lazyllm import pipeline, bind +from chat.components.generate import AggregateComponent, RAGContextFormatter, CustomOutputParser +from chat.pipelines.builders import get_automodel +from chat.prompts.rag_answer import RAG_ANSWER_SYSTEM +from chat.config import LLM_TYPE_THINK + + +def _answer_llm(): + wrapped = get_automodel('llm', wrap_simple_llm=True) + inner = wrapped.llm + if getattr(inner, '_prompt', None) is not None: + inner._prompt._set_model_configs(system=RAG_ANSWER_SYSTEM) + return wrapped + + +def get_ppl_generate(stream=False): + with lazyllm.save_pipeline_result(): + with pipeline() as ppl: + ppl.aggregate = AggregateComponent() + ppl.formatter = RAGContextFormatter() | bind(query=ppl.kwargs['query'], nodes=ppl.aggregate) + ppl.answer = _answer_llm() | \ + bind(stream=stream, llm_chat_history=[], files=[], priority=1) + ppl.parser = CustomOutputParser(llm_type_think=LLM_TYPE_THINK) | bind( + stream=stream, + recall_result=ppl.input, + aggregate=ppl.aggregate, + image_files=[], + debug=ppl.kwargs['debug']) + + return ppl diff --git a/algorithm/chat/pipelines/builders/get_ppl_search.py b/algorithm/chat/pipelines/builders/get_ppl_search.py new file mode 100644 index 0000000..6392086 --- /dev/null +++ b/algorithm/chat/pipelines/builders/get_ppl_search.py @@ -0,0 +1,48 @@ +from typing import List, Any +import lazyllm +from lazyllm import pipeline, parallel, bind, ifs +from lazyllm.tools.rag import Reranker +from lazyllm.tools.rag.rank_fusion.reciprocal_rank_fusion import RRFFusion +from chat.components.process import AdaptiveKComponent, ContextExpansionComponent +from chat.pipelines.builders import get_automodel, get_retriever, get_remote_docment +from chat.utils.load_config import get_retrieval_settings + + +def _adaptive_get_token_len(n: Any) -> int: + txt = getattr(n, 'text', '') or '' + return max(1, len(txt) // 4) + + +def get_ppl_search(url: str, retriever_configs: List[dict] = None, topk=20, k_max=10): + if retriever_configs is None: + retriever_configs = get_retrieval_settings().retriever_configs + + retrieval = get_retriever(url, retriever_configs) + retrievers = retrieval.kb_retrievers + tmp_retriever = retrieval.tmp_retriever_pipeline + document = get_remote_docment(url) + + with lazyllm.save_pipeline_result(): + with pipeline() as search_ppl: + search_ppl.parse_input = lambda x: x['query'] + search_ppl.divert = ifs( + (lambda _, x: bool(x.get('files'))) | bind(x=search_ppl.input), + tpath=tmp_retriever | bind(files=search_ppl.input['files']), + fpath=parallel(*[(retriever | bind(filters=search_ppl.input['filters'])) for retriever in retrievers]) + ) + search_ppl.merge_results = lambda *args: args + search_ppl.join = RRFFusion(top_k=50) + search_ppl.reranker = Reranker('ModuleReranker', model=get_automodel('reranker'), topk=topk) | bind( + query=search_ppl.input['query'] + ) + search_ppl.adaptive_k = AdaptiveKComponent(bias=2, k_max=k_max, gap_tau=0.2, + get_token_len=_adaptive_get_token_len, + max_tokens=2048) + search_ppl.ctx_expand = ContextExpansionComponent( + document=document, + token_budget=1500, + score_decay=0.97, + max_seeds=1 + ) + + return search_ppl diff --git a/algorithm/chat/pipelines/builders/get_retriever.py b/algorithm/chat/pipelines/builders/get_retriever.py new file mode 100644 index 0000000..232304b --- /dev/null +++ b/algorithm/chat/pipelines/builders/get_retriever.py @@ -0,0 +1,41 @@ +from typing import List, NamedTuple + +from lazyllm import Retriever, bind, pipeline, Document +from lazyllm.tools.rag import TempDocRetriever + +from chat.pipelines.builders.get_models import get_automodel +from chat.utils.load_config import get_retrieval_settings +from chat.config import DEFAULT_TMP_BLOCK_TOPK + + +class SearchRetrievalParts(NamedTuple): + kb_retrievers: List[Retriever] + tmp_retriever_pipeline: object + + +def get_remote_docment(url: str) -> Document: + url = url.split(',') + if len(url) == 1: + url, name = url[0], '__default__' + else: + url, name = url[0], url[1] + return Document(url=f'{url}/_call', name=name) + + +def get_retriever(url: str, retriever_configs: List[dict], *, + tmp_block_topk: int = DEFAULT_TMP_BLOCK_TOPK + ) -> SearchRetrievalParts: + document = get_remote_docment(url) + kb_retrievers = [Retriever(document, **cfg) for cfg in retriever_configs] + + settings = get_retrieval_settings() + ref_docs_retriever = TempDocRetriever(embed=get_automodel(settings.temp_doc_embed_key)) + ref_docs_retriever.add_subretriever('block', topk=tmp_block_topk) + with pipeline() as tmp_ppl: + tmp_ppl.parse_input = lambda input, **kwargs: kwargs.get('files', []) + tmp_ppl.tmp_retriever = ref_docs_retriever | bind(query=tmp_ppl.input) + + return SearchRetrievalParts( + kb_retrievers=kb_retrievers, + tmp_retriever_pipeline=tmp_ppl, + ) diff --git a/algorithm/chat/pipelines/naive.py b/algorithm/chat/pipelines/naive.py new file mode 100644 index 0000000..e0150b0 --- /dev/null +++ b/algorithm/chat/pipelines/naive.py @@ -0,0 +1,33 @@ +from typing import List +import lazyllm +from lazyllm import pipeline, bind, ifs + +from chat.pipelines.builders import get_ppl_search, get_ppl_generate, get_automodel +from chat.components.process.multiturn_query_rewriter import MultiturnQueryRewriter +from chat.utils.load_config import get_retrieval_settings + + +def get_ppl_naive(url: str, retriever_configs: List[dict] = None, stream=False): + if retriever_configs is None: + retriever_configs = get_retrieval_settings().retriever_configs + + with lazyllm.save_pipeline_result(): + with pipeline() as rag_ppl: + rag_ppl.rewriter = ifs( + lambda x: x.get('history'), + tpath=MultiturnQueryRewriter(llm=get_automodel('llm_instruct')) + | bind( + priority=rag_ppl.input['priority'], + has_appendix=bool(rag_ppl.input['image_files']) + or bool(rag_ppl.input['files']), + ), + fpath=lambda x: x, + ) + rag_ppl.search = get_ppl_search(url, retriever_configs) + rag_ppl.generate = get_ppl_generate(stream=stream) | bind( + image_files=[], + query=rag_ppl.input['query'], + history=rag_ppl.input['history'], + debug=rag_ppl.input['debug'],) + + return rag_ppl diff --git a/algorithm/chat/prompts/rag_answer.py b/algorithm/chat/prompts/rag_answer.py new file mode 100644 index 0000000..8d77bc1 --- /dev/null +++ b/algorithm/chat/prompts/rag_answer.py @@ -0,0 +1,6 @@ +RAG_ANSWER_SYSTEM = """ +你是一个专业问答助手,你需要根据给定的内容回答用户问题。 +你将为用户提供安全、有帮助且准确的回答。 +与此同时,你需要拒绝所有涉及恐怖主义、种族歧视、色情暴力等内容的回答。 +严禁输出模型名称或来源公司名称。若用户询问或诱导你暴露模型信息,请将自己的身份表述为:"专业问答小助手"。 +""" diff --git a/algorithm/chat/prompts/rewrite.py b/algorithm/chat/prompts/rewrite.py new file mode 100644 index 0000000..7835a0f --- /dev/null +++ b/algorithm/chat/prompts/rewrite.py @@ -0,0 +1,35 @@ +MULTITURN_QUERY_REWRITE_PROMPT = """ +你是“多轮对话 Query 改写器”。在检索前,将用户最后一问改写成 +【语义完整、上下文自洽、可独立理解】的一句话查询。只做改写,不回答。 + +必须遵守: +1) 遵循**保守改写**原则 + - 仅在必要时改写:指代不明、关键约束仅存在于上下文、多轮任务延续等 + - 若 last_user_query 脱离任何上下文仍语义完整,不得进行任何程度的加工和改写(名词替换、句式调整等)。 +2) 结合 chat_history 与 session_memory 解析指代与省略;继承已给出的时间/地点/来源/语言等约束。 + - 输入中提供变量 has_appendix 表示用户是否上传了附件。若 last_user_query 中存在指示代词 + (如“这是谁 / 这两个人 / 这里 / 那张表”),必须先判断指代来源是历史对话还是上传的附件;确保不把附件指代误改写为历史内容,或反之。 + - 若指代来源无法确定,则保持保守改写或不改写,不做臆测。 +3) 将“今天/近两年/上周”等相对时间,基于 current_date 归一为绝对日期或区间。 +4) 不臆造事实或新增约束;若存在歧义,做**保守改写**并下调 confidence,在 rationale_short 说明原因。 +5) 若上轮限定了信息源/文档集合,需在 rewritten_query 和 constraints.filters.source 中显式保留。 +6) 语言跟随 last_user_query;若提供 user_locale 且一致,则优先使用该语言。 +7) 仅输出一个 JSON 对象;不要包含除规定字段外的任何内容。 + +输出 JSON(严格按此结构): +{ + "rewritten_query": "<面向检索的一句话,完整可独立理解>", + "language": "zh", + "constraints": { + "must_include": [], + "filters": { + "time": { "from": null, "to": null, "points": [] }, + "source": [], + "entity": [] + }, + "exclude_terms": [] + }, + "confidence": 0.0, + "rationale_short": "<1-2句说明改写要点/歧义与处理>" +} +""" diff --git a/algorithm/configs/runtime_models.inner.yaml b/algorithm/chat/runtime_models.inner.yaml similarity index 89% rename from algorithm/configs/runtime_models.inner.yaml rename to algorithm/chat/runtime_models.inner.yaml index 44edab7..74cac75 100644 --- a/algorithm/configs/runtime_models.inner.yaml +++ b/algorithm/chat/runtime_models.inner.yaml @@ -31,8 +31,8 @@ reranker: embeddings: # Dense vector embedding. - # Previously aligned with: bgem3_emb_dense_custom - embed_1: + # Key must match what the remote document server was indexed with. + bge_m3_dense: source: bgem3embed type: embed model: bgem3_emb_dense_custom @@ -45,8 +45,7 @@ embeddings: nlist: 128 # Sparse vector embedding. - # Previously aligned with: bgem3_emb_sparse_custom - embed_2: + bge_m3_sparse: source: bgem3embed type: embed model: bgem3_emb_sparse_custom @@ -57,8 +56,8 @@ embeddings: metric_type: IP retrieval: - temp_doc_embed_key: embed_1 - file_search_embed_key: embed_2 + temp_doc_embed_key: bge_m3_dense + file_search_embed_key: bge_m3_sparse # The default behavior already builds: # - line -> block for each enabled embed key diff --git a/algorithm/configs/runtime_models.yaml b/algorithm/chat/runtime_models.yaml similarity index 100% rename from algorithm/configs/runtime_models.yaml rename to algorithm/chat/runtime_models.yaml diff --git a/algorithm/chat/tools/sql.py b/algorithm/chat/tools/sql.py deleted file mode 100644 index 3701f59..0000000 --- a/algorithm/chat/tools/sql.py +++ /dev/null @@ -1,209 +0,0 @@ -import os -import re -import datetime -from typing import Callable - -from lazyllm import pipeline -from lazyllm.module import ModuleBase -from lazyllm.components import ChatPrompter -from lazyllm.tools.utils import chat_history_to_str -from lazyllm.tools.sql import SqlManager - -from chat.component.tools.encrypt_sql_manager import EncryptSqlManager - - -ENCRYPT_SQL_MANAGER = os.getenv('ENCRYPT_SQL_MANAGER') == 'True' -SQLManager = EncryptSqlManager if ENCRYPT_SQL_MANAGER else SqlManager - - -sql_query_instruct_template = """ -Given the following SQL tables and current date {current_date}, your job is to write sql queries in {db_type} given a user’s request. - -{schema_desc} - -Alert: Just reply the sql query in a code block start with triple-backticks and keyword "sql" -""" # noqa E501 - - -mongodb_query_instruct_template = """ -Current date is {current_date}. -You are a seasoned expert with 10 years of experience in crafting NoSQL queries for {db_type}. -I will provide a collection description in a specified format. -Your task is to analyze the user_question, which follows certain guidelines, and generate a NoSQL MongoDB aggregation pipeline accordingly. - -{schema_desc} - -Note: Please return the json pipeline in a code block start with triple-backticks and keyword "json". -""" # noqa E501 - -db_explain_instruct_template = """ -你是数据库讲解助手。请结合上下文,基于已执行 SQL 的真实结果,直接回答用户的问题,避免无关赘述与臆测。 -#【上下文】 -##聊天历史: -``` -{history} -``` - -##库表/字段说明: -``` -{schema_desc} -``` - -##已执行的 SQL: -``` -{statement} -``` - -##查询结果: -``` -{result} -``` - -#【写作要求】 -1) 语言:与用户输入保持一致(input 的语言是什么你就用什么);不要翻译或改写原始结果中的字段名与取值。 -2) 目标:先给出“结论”,再给出“依据与说明”;只引用结果中确凿可见的数据。 -3) 可读性:必要时用简短列表或表格展示关键信息;如需展示明细,最多展示前 10 行并注明总行数。 -4) 解释:点明关键筛选条件/分组/排序/时间范围等对结论的影响,避免过多 SQL 行话。 -5) 边界情况: - - 若结果为空:明确说明“未查询到匹配数据”,并给出 1–3 条可执行的追加查询建议。 - - 若存在 NULL/缺失值/单位不一致:如实标注,不做猜测。 -6) 诚信:不得编造不存在的字段或外部事实;无法从结果回答时要坦诚说明。 - -input:{user_query} -""" - - -class SqlGenSchema(ModuleBase): - def __init__( - self, - return_trace: bool = False, - ) -> None: - super().__init__(return_trace=return_trace) - - def forward(self, databases: list[dict], **kwargs): - database = databases[0] - source = database.get('source', {}) - kind = source.get('kind') - db_name = source.get('database') - host = source.get('host', None) - port = source.get('port', None) - user = source.get('user', None) - password = source.get('password', None) - description = database.get('description', None) - sql_manager = SQLManager( - db_type=kind, - host=host, - port=port, - user=user, - password=password, - db_name=db_name, - tables_info_dict=description - ) - return sql_manager.desc - - -class SqlGenerator(ModuleBase): - EXAMPLE_TITLE = 'Here are some example: ' - - def __init__( - self, - llm, - sql_examples: str = '', - sql_post_func: Callable = None, - return_trace: bool = False, - ) -> None: - super().__init__(return_trace=return_trace) - self.sql_post_func = sql_post_func - - self._pattern = re.compile(r'```sql(.+?)```', re.DOTALL) - self.example = sql_examples - - self._llm = llm - - def extract_sql_from_response(self, str_response: str) -> str: - matches = self._pattern.findall(str_response) - if matches: - extracted_content = matches[0].strip() - return extracted_content if not self.sql_post_func else self.sql_post_func(extracted_content) - else: - return '' - - def forward(self, query: str, databases: list[dict], **kwargs): - database = databases[0] - source = database.get('source', {}) - kind = source.get('kind') - schema_desc = SqlGenSchema().forward(databases) - - current_date = datetime.datetime.now().strftime('%Y-%m-%d') - sql_query_instruct = sql_query_instruct_template.format( - current_date=current_date, db_type=kind, schema_desc=schema_desc) - query_prompter = ChatPrompter(instruction=sql_query_instruct) - - with pipeline() as ppl: - ppl.llm_query = self._llm.share(prompt=query_prompter).used_by(self._module_id) - ppl.sql_extractor = self.extract_sql_from_response - return ppl(query, **kwargs) - - -class SqlExecute(ModuleBase): - EXAMPLE_TITLE = 'Here are some example: ' - - def __init__( - self, - sql_post_func: Callable = None, - return_trace: bool = False, - ) -> None: - super().__init__(return_trace=return_trace) - self.sql_post_func = sql_post_func - - def forward(self, statement: str, databases: list[dict], **kwargs): - database = databases[0] - source = database.get('source', {}) - kind = source.get('kind') - db_name = source.get('database') - host = source.get('host', None) - port = source.get('port', None) - user = source.get('user', None) - password = source.get('password', None) - description = database.get('description', None) - sql_manager = SQLManager( - db_type=kind, - host=host, - port=port, - user=user, - password=password, - db_name=db_name, - tables_info_dict=description - ) - return sql_manager.execute_query(statement=statement) - - -class SqlExplain(ModuleBase): - EXAMPLE_TITLE = 'Here are some example: ' - - def __init__( - self, - llm, - sql_examples: str = '', - sql_post_func: Callable = None, - return_trace: bool = False, - ) -> None: - super().__init__(return_trace=return_trace) - self.sql_post_func = sql_post_func - self.example = sql_examples - self._llm = llm - - def forward(self, input: str, user_query: str, statement: str, databases: list[dict], **kwargs): - database = databases[0] - source = database.get('source', {}) - kind = source.get('kind') - schema_desc = SqlGenSchema().forward(databases) - sql_explain_instruct = db_explain_instruct_template.format( - history=chat_history_to_str(history=kwargs.get('llm_chat_history', [])), - db_type=kind, - schema_desc=schema_desc, - statement=statement, - result=input, - user_query=user_query - ) - return self._llm(sql_explain_instruct, **kwargs) diff --git a/algorithm/chat/utils/__init__.py b/algorithm/chat/utils/__init__.py new file mode 100644 index 0000000..e2291b6 --- /dev/null +++ b/algorithm/chat/utils/__init__.py @@ -0,0 +1,24 @@ +# 工具层 +# 本层主要包含了chat流程中会用到的辅助函数和各类定义 +# schema.py - 各类pydantic数据的定义和数据类 +# config.py - 配置管理,环境变量和常量 +# helpers.py - 辅助函数(包含工具schema转换等) +# message.py - 消息数据模型(已迁移到 schema.py) +# url.py - URL处理工具 +# stream_scanner.py - 流式扫描工具 + +from chat.utils.schema import ( + BaseMessage, SessionMemory, + MiddleResults, ToolMemory, ToolCall, + PlanStep, TaskContext +) +from chat.config import URL_MAP, MAX_CONCURRENCY, LAZYRAG_LLM_PRIORITY +from chat.utils.helpers import tool_schema_to_string + +__all__ = [ + 'BaseMessage', 'SessionMemory', + 'MiddleResults', 'ToolMemory', 'ToolCall', + 'PlanStep', 'TaskContext', + 'URL_MAP', 'LAZYRAG_LLM_PRIORITY', + 'MAX_CONCURRENCY', 'tool_schema_to_string' +] diff --git a/algorithm/chat/utils/helpers.py b/algorithm/chat/utils/helpers.py new file mode 100644 index 0000000..556afa0 --- /dev/null +++ b/algorithm/chat/utils/helpers.py @@ -0,0 +1,57 @@ +import os +from pathlib import Path +from typing import List, Optional, Tuple +from fastapi import HTTPException + +from chat.config import MOUNT_BASE_DIR, IMAGE_EXTENSIONS + + +def validate_and_resolve_files(files: Optional[List[str]]) -> Tuple[List[str], List[str]]: + if not files: + return [], [] + + root = Path(MOUNT_BASE_DIR).resolve() + resolved: List[str] = [] + for f in files: + if '\x00' in f: + raise HTTPException(status_code=400, detail='Invalid path') + p = Path(f) + cand = (p if p.is_absolute() else root / p).resolve() + if not cand.is_relative_to(root): + raise HTTPException(status_code=400, detail='Path outside mount directory') + if not cand.is_file() or not os.access(cand, os.R_OK): + raise HTTPException(status_code=400, detail=f'File not accessible: {f}') + resolved.append(str(cand)) + + image_files = [p for p in resolved if p.lower().endswith(IMAGE_EXTENSIONS)] + other_files = [p for p in resolved if p not in image_files] + return other_files, image_files + + +def tool_schema_to_string( + tool_schema: dict, + include_params: bool = True +) -> str: + lines = [] + + for tool_name, tool_info in tool_schema.items(): + lines.append(f'TOOL NAME: {tool_name}') + + desc = tool_info.get('description') + if desc: + lines.append('DESCRIPTION:') + for sent in desc.split('. '): + sent = sent.strip() + if sent: + lines.append(f"- {sent.rstrip('.')}.") + + if include_params: + params = tool_info.get('parameters', {}) + if params: + lines.append('PARAMETERS:') + for name, info in params.items(): + t = info.get('type', 'Any') + d = info.get('des', '') + lines.append(f'- {name}: {t}' + (f' — {d}' if d else '')) + + return '\n'.join(lines).strip() diff --git a/algorithm/chat/utils/load_config.py b/algorithm/chat/utils/load_config.py new file mode 100644 index 0000000..418d6fc --- /dev/null +++ b/algorithm/chat/utils/load_config.py @@ -0,0 +1,191 @@ +import functools +import os +import re +from copy import deepcopy +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import yaml + +_CHAT_DIR = Path(__file__).resolve().parents[1] +_INNER_CONFIG_PATH = _CHAT_DIR / 'runtime_models.inner.yaml' +_EXTERNAL_CONFIG_PATH = _CHAT_DIR / 'runtime_models.yaml' +_ENV_PATTERN = re.compile(r'\$\{([^}:]+)(?::-([^}]*))?\}') + +_NON_MODEL_KEYS = frozenset({'embeddings', 'retrieval', 'roles'}) + +_DEFAULT_INDEX_KWARGS: Dict[str, Any] = { + 'index_type': 'IVF_FLAT', + 'metric_type': 'COSINE', + 'params': {'nlist': 128}, +} + + +@dataclass(frozen=True) +class RetrievalSettings: + embed_keys: List[str] + index_kwargs: List[Dict[str, Any]] + retriever_configs: List[Dict[str, Any]] + temp_doc_embed_key: str + file_search_embed_key: str + + +def _expand_env_placeholders(value: Any, config_path: str) -> Any: + if isinstance(value, dict): + return {k: _expand_env_placeholders(v, config_path) for k, v in value.items()} + if isinstance(value, list): + return [_expand_env_placeholders(item, config_path) for item in value] + if not isinstance(value, str): + return value + + def _replace(match: re.Match) -> str: + env_name = match.group(1) + default = match.group(2) + resolved = os.getenv(env_name) + if resolved is not None: + return resolved + if default is not None: + return default + raise ValueError( + f'Environment variable `{env_name}` is required by model config `{config_path}`' + ) + + expanded = _ENV_PATTERN.sub(_replace, value) + if isinstance(expanded, str): + expanded = expanded.strip() + return expanded or None + return expanded + + +def get_config_path() -> Path: + custom = os.getenv('LAZYRAG_MODEL_CONFIG_PATH') + if custom: + return Path(custom) + use_inner = os.getenv('LAZYRAG_USE_INNER_MODEL', '').lower() in ('1', 'true', 'yes') + return _INNER_CONFIG_PATH if use_inner else _EXTERNAL_CONFIG_PATH + + +def load_model_config(config_path: str | None = None) -> Dict[str, Any]: + resolved = Path(config_path) if config_path else get_config_path() + if not resolved.exists(): + raise FileNotFoundError( + f'Model config `{resolved}` not found. ' + 'Set LAZYRAG_MODEL_CONFIG_PATH or LAZYRAG_USE_INNER_MODEL=true for internal models.' + ) + with resolved.open('r', encoding='utf-8') as f: + raw = yaml.safe_load(f) or {} + if not isinstance(raw, dict): + raise ValueError(f'Model config `{resolved}` root must be a mapping.') + return _expand_env_placeholders(raw, str(resolved)) + + +_config_cache: Dict[str, Any] | None = None + + +def _get_cached_config() -> Dict[str, Any]: + global _config_cache + if _config_cache is None: + _config_cache = load_model_config() + return _config_cache + + +def _get_roles(cfg: Dict[str, Any]) -> Dict[str, Any]: + return cfg.get('roles', cfg) + + +def get_role_config(role: str) -> Tuple[str, Dict[str, Any]]: + cfg = _get_cached_config() + roles = _get_roles(cfg) + + if role in roles and role not in _NON_MODEL_KEYS: + entry = roles[role] + elif isinstance(roles.get('embeddings'), dict) and role in roles['embeddings']: + entry = roles['embeddings'][role] + else: + available = [k for k in roles if k not in _NON_MODEL_KEYS] + embed_keys = list(roles.get('embeddings', {}).keys()) + raise KeyError( + f'Unknown model role {role!r}. ' + f'Available roles: {available}, embed keys: {embed_keys}' + ) + + if not isinstance(entry, dict): + raise ValueError(f'Model role `{role}` config must be a mapping.') + + config = deepcopy(entry) + model_name = config.pop('model', None) or config.pop('name', None) + if not model_name: + raise ValueError(f'Model role `{role}` missing `model` field.') + config.pop('index_kwargs', None) + config.pop('name', None) + return model_name, config + + +def _build_default_retriever_configs(embed_keys: List[str], topk: int = 20) -> List[Dict[str, Any]]: + configs: List[Dict[str, Any]] = [] + for ek in embed_keys: + configs.append({'group_name': 'line', 'embed_keys': [ek], 'topk': topk, 'target': 'block'}) + for ek in embed_keys: + configs.append({'group_name': 'block', 'embed_keys': [ek], 'topk': topk}) + return configs + + +def _default_file_search_embed_key(embed_keys: List[str], index_kwargs: List[Dict[str, Any]]) -> str: + for ik in index_kwargs: + if 'SPARSE' in str(ik.get('index_type', '')).upper(): + return ik['embed_key'] + return embed_keys[0] + + +@functools.lru_cache(maxsize=1) +def get_retrieval_settings(config_path: str | None = None) -> RetrievalSettings: + cfg = load_model_config(config_path) + roles = _get_roles(cfg) + embeddings = roles.get('embeddings', {}) + + embed_keys: List[str] = [] + index_kwargs: List[Dict[str, Any]] = [] + for key, entry in embeddings.items(): + if not entry or not isinstance(entry, dict): + continue + embed_keys.append(key) + ik = deepcopy(entry.get('index_kwargs')) if isinstance(entry.get('index_kwargs'), dict) \ + else deepcopy(_DEFAULT_INDEX_KWARGS) + ik['embed_key'] = key + index_kwargs.append(ik) + + if not embed_keys: + raise ValueError( + 'At least one embedding must be configured under `embeddings`.' + ) + + retrieval = cfg.get('retrieval', roles.get('retrieval', {})) or {} + + temp_doc_embed_key = retrieval.get('temp_doc_embed_key', embed_keys[0]) + if temp_doc_embed_key not in embed_keys: + raise ValueError( + f'temp_doc_embed_key `{temp_doc_embed_key}` not in active embeds: {embed_keys}' + ) + + file_search_embed_key = retrieval.get( + 'file_search_embed_key', + _default_file_search_embed_key(embed_keys, index_kwargs), + ) + if file_search_embed_key not in embed_keys: + raise ValueError( + f'file_search_embed_key `{file_search_embed_key}` not in active embeds: {embed_keys}' + ) + + retriever_configs = retrieval.get('retriever_configs') + if retriever_configs is None: + topk = int(retrieval.get('default_topk', 20)) + retriever_configs = _build_default_retriever_configs(embed_keys, topk) + + return RetrievalSettings( + embed_keys=embed_keys, + index_kwargs=index_kwargs, + retriever_configs=retriever_configs, + temp_doc_embed_key=temp_doc_embed_key, + file_search_embed_key=file_search_embed_key, + ) diff --git a/algorithm/chat/utils/message.py b/algorithm/chat/utils/message.py deleted file mode 100644 index c610aef..0000000 --- a/algorithm/chat/utils/message.py +++ /dev/null @@ -1,25 +0,0 @@ -from datetime import datetime -from typing import List, Optional, Literal -from pydantic import BaseModel, Field, ConfigDict - - -class BaseMessage(BaseModel): - """单轮对话消息""" - model_config = ConfigDict(extra='forbid') - - role: Literal['user', 'assistant', 'system'] = Field(..., description='消息角色') - content: str = Field(..., description='消息文本内容') - time: Optional[datetime] = Field( - default=None, - description='消息时间(可选;ISO8601,可含时区)' - ) - - -class SessionMemory(BaseModel): - """会话内已确定的实体/意图/约束""" - model_config = ConfigDict(extra='forbid') - - topic: Optional[str] = Field(default=None, description='当前主题/任务(可选)') - entities: List[str] = Field(default_factory=list, description='相关实体列表') - time_hints: List[str] = Field(default_factory=list, description='相对时间线索(如:近三年、2023Q4)') - source_scope: List[str] = Field(default_factory=list, description='限定信息源(如:公司年报、官方博客)') diff --git a/algorithm/chat/utils/schema.py b/algorithm/chat/utils/schema.py new file mode 100644 index 0000000..eec955f --- /dev/null +++ b/algorithm/chat/utils/schema.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, List, Optional, Literal +from pydantic import BaseModel, Field, ConfigDict + + +class BaseMessage(BaseModel): + """单轮对话消息""" + model_config = ConfigDict(extra='forbid') + + role: Literal['user', 'assistant', 'system'] = Field(..., description='消息角色') + content: str = Field(..., description='消息文本内容') + time: Optional[datetime] = Field( + default=None, + description='消息时间(可选;ISO8601,可含时区)' + ) + + +class SessionMemory(BaseModel): + """会话内已确定的实体/意图/约束""" + model_config = ConfigDict(extra='forbid') + + topic: Optional[str] = Field(default=None, description='当前主题/任务(可选)') + entities: List[str] = Field(default_factory=list, description='相关实体列表') + time_hints: List[str] = Field(default_factory=list, description='相对时间线索(如:近三年、2023Q4)') + source_scope: List[str] = Field(default_factory=list, description='限定信息源(如:公司年报、官方博客)') + + +@dataclass +class MiddleResults: + evaluation_result: dict = field(default_factory=dict) + raw_results: list = field(default_factory=list) + formatted_results: list = field(default_factory=list) + agg_results: dict = field(default_factory=dict) + + +@dataclass +class ToolMemory: + nodes: list = field(default_factory=list) + + +@dataclass +class ToolCall: + name: str + input: dict + output: Any + + +@dataclass +class PlanStep: + step_id: int + goal: str + tool: str + status: str = 'pending' + raw_results: list = field(default_factory=list) + formatted_results: list = field(default_factory=list) + extracted_results: list = field(default_factory=list) + inference: str = '' + + +@dataclass +class TaskContext: + query: str = '' + global_params: dict = field(default_factory=dict) + tool_params: dict = field(default_factory=dict) + pending_steps: List[PlanStep] = field(default_factory=list) + executed_steps: List[PlanStep] = field(default_factory=list) + middle_results: MiddleResults = field(default_factory=MiddleResults) + inferences: List[str] = field(default_factory=list) + reasoning_process_stream: List[str] = field(default_factory=list) + answer: str = '' diff --git a/algorithm/chat/utils/stream_scanner.py b/algorithm/chat/utils/stream_scanner.py index f96b4b0..62e5a17 100644 --- a/algorithm/chat/utils/stream_scanner.py +++ b/algorithm/chat/utils/stream_scanner.py @@ -10,11 +10,7 @@ from chat.utils.url import get_url_basename -__all__ = ['BasePlugin', 'CitationPlugin', 'ImagePlugin', 'IncrementalScanner'] - - IMAGE_PATTERN = re.compile(r'!\[([^\]]*)\]\(([^)]+)\)') - # Qwen-style think delimiters (lengths 7 and 8; must stay in sync with parsers elsewhere) _THINK_OPEN = '' _THINK_CLOSE = '' diff --git a/algorithm/common/__init__.py b/algorithm/common/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/algorithm/common/model/__init__.py b/algorithm/common/model/__init__.py deleted file mode 100644 index 1282d60..0000000 --- a/algorithm/common/model/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -from .embed import BgeM3Embed -from .reranker import Qwen3Rerank -from .utils import ( - DEFAULT_EMBED_KEYS, - RuntimeModelSettings, - build_embedding_models, - build_model, - get_model, - get_runtime_model_config_path, - get_runtime_model_settings, - load_runtime_model_config, -) - -__all__ = [ - 'BgeM3Embed', - 'Qwen3Rerank', - 'DEFAULT_EMBED_KEYS', - 'RuntimeModelSettings', - 'build_embedding_models', - 'build_model', - 'get_model', - 'get_runtime_model_config_path', - 'get_runtime_model_settings', - 'load_runtime_model_config', -] diff --git a/algorithm/common/model/embed.py b/algorithm/common/model/embed.py deleted file mode 100644 index e911c15..0000000 --- a/algorithm/common/model/embed.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Dict, List, Union - -from lazyllm.module.llms.onlinemodule.base import LazyLLMOnlineEmbedModuleBase - - -class BgeM3Embed(LazyLLMOnlineEmbedModuleBase): - NO_PROXY = True - - def __init__(self, embed_url: str = '', embed_model_name: str = 'custom', api_key: str = None, - skip_auth: bool = True, batch_size: int = 16, **kw): - super().__init__(embed_url, '' if skip_auth else (api_key or ''), embed_model_name, - skip_auth=skip_auth, batch_size=batch_size, **kw) - - def _set_embed_url(self): - pass - - def _encapsulated_data(self, input: Union[List, str], **kwargs): - model = kwargs.get('model', self._embed_model_name) - extras = {k: v for k, v in kwargs.items() if k not in ('model',)} - if isinstance(input, str): - json_data: Dict = {'inputs': input} - if model: - json_data['model'] = model - json_data.update(extras) - return json_data - text_batch = [input[i: i + self._batch_size] for i in range(0, len(input), self._batch_size)] - out = [] - for texts in text_batch: - item: Dict = {'inputs': texts} - if model: - item['model'] = model - item.update(extras) - out.append(item) - return out - - def _parse_response(self, response: Union[Dict, List], input: Union[List, str] - ) -> Union[List[float], List[List[float]], Dict]: - if isinstance(response, dict): - if 'data' in response: - return super()._parse_response(response, input) - return response - if isinstance(response, list): - if not response: - raise RuntimeError('empty embedding response') - if isinstance(input, str): - first = response[0] - return response if isinstance(first, float) else first - return response - raise RuntimeError(f'unexpected embedding response type: {type(response)!r}') diff --git a/algorithm/common/model/utils.py b/algorithm/common/model/utils.py deleted file mode 100644 index deccff3..0000000 --- a/algorithm/common/model/utils.py +++ /dev/null @@ -1,351 +0,0 @@ -import atexit -import functools -import hashlib -import os -import re -import shutil -import tempfile -from copy import deepcopy -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Dict, List - -import yaml -from lazyllm import AutoModel - -DEFAULT_RUNTIME_MODEL_CONFIG_PATH = Path(__file__).resolve().parents[2] / 'configs' / 'runtime_models.yaml' -DEFAULT_EMBED_KEYS = ('embed_1', 'embed_2', 'embed_3') -DEFAULT_EMBED_INDEX_KWARGS = { - 'index_type': 'IVF_FLAT', - 'metric_type': 'COSINE', - 'params': { - 'nlist': 128, - }, -} -_ENV_PATTERN = re.compile(r'\$\{([^}:]+)(?::-([^}]*))?\}') -_RUNTIME_AUTO_MODEL_DIR = Path(tempfile.gettempdir()) / 'lazyrag-runtime-auto-model' - - -@dataclass(frozen=True) -class RuntimeModelSettings: - llm: Any - llm_instruct: Any - reranker: Any - embeddings: Dict[str, Any] - embed_keys: List[str] - index_kwargs: List[Dict[str, Any]] - retriever_configs: List[Dict[str, Any]] - temp_doc_embed_key: str - file_search_embed_key: str - - -def get_runtime_model_config_path() -> str: - return os.getenv('LAZYRAG_MODEL_CONFIG_PATH') or str(DEFAULT_RUNTIME_MODEL_CONFIG_PATH) - - -def _cleanup_runtime_auto_model_dir() -> None: - shutil.rmtree(_RUNTIME_AUTO_MODEL_DIR, ignore_errors=True) - - -atexit.register(_cleanup_runtime_auto_model_dir) - - -@functools.lru_cache(maxsize=64) -def _write_runtime_auto_model_config(serialized_config: str) -> str: - config = yaml.safe_load(serialized_config) - model_name = config['model'] - digest = hashlib.sha256(serialized_config.encode('utf-8')).hexdigest()[:16] - safe_model_name = re.sub(r'[^A-Za-z0-9_.-]+', '-', model_name).strip('-') or 'runtime-model' - _RUNTIME_AUTO_MODEL_DIR.mkdir(parents=True, exist_ok=True) - config_path = _RUNTIME_AUTO_MODEL_DIR / f'{safe_model_name}-{digest}.yaml' - temp_fd, temp_path = tempfile.mkstemp( - dir=_RUNTIME_AUTO_MODEL_DIR, - prefix=f'.{safe_model_name}-{digest}-', - suffix='.yaml', - ) - try: - with os.fdopen(temp_fd, 'w', encoding='utf-8') as file: - yaml.safe_dump({model_name: [config]}, file, sort_keys=False) - os.replace(temp_path, config_path) - except Exception: - try: - os.unlink(temp_path) - except FileNotFoundError: - pass - raise - return str(config_path) - - -def _build_inline_auto_model(model_name: str, config: Dict[str, Any]): - inline_config = deepcopy(config) - inline_config['model'] = model_name - serialized_config = yaml.safe_dump(inline_config, sort_keys=True) - return AutoModel(model=model_name, config=_write_runtime_auto_model_config(serialized_config)) - - -def _strip_optional_string(value: Any) -> Any: - if not isinstance(value, str): - return value - value = value.strip() - return value or None - - -def _expand_env_placeholders(value: Any) -> Any: - if isinstance(value, dict): - return {k: _expand_env_placeholders(v) for k, v in value.items()} - if isinstance(value, list): - return [_expand_env_placeholders(item) for item in value] - if not isinstance(value, str): - return value - - def _replace(match: re.Match) -> str: - env_name = match.group(1) - default = match.group(2) - resolved = os.getenv(env_name) - if resolved is not None: - return resolved - if default is not None: - return default - raise ValueError( - f'Environment variable `{env_name}` is required by model config ' - f'`{get_runtime_model_config_path()}`' - ) - - expanded = _ENV_PATTERN.sub(_replace, value) - return _strip_optional_string(expanded) - - -def load_runtime_model_config(config_path: str | None = None) -> Dict[str, Any]: - resolved_path = Path(config_path or get_runtime_model_config_path()) - if not resolved_path.exists(): - raise FileNotFoundError( - f'Runtime model config `{resolved_path}` not found. ' - 'Set `LAZYRAG_MODEL_CONFIG_PATH` or create the default config file.' - ) - - with resolved_path.open('r', encoding='utf-8') as file: - raw = yaml.safe_load(file) or {} - if not isinstance(raw, dict): - raise ValueError(f'Runtime model config `{resolved_path}` must be a mapping.') - return _expand_env_placeholders(raw) - - -def _get_runtime_roles(config: Dict[str, Any]) -> Dict[str, Any]: - roles = config.get('roles', config) - if not isinstance(roles, dict): - raise ValueError('Runtime model config `roles` must be a mapping.') - return roles - - -def _normalize_model_entry(name: str, entry: Dict[str, Any], expected_type: str) -> Dict[str, Any]: - if not isinstance(entry, dict): - raise ValueError(f'Model role `{name}` must be a mapping.') - - normalized = deepcopy(entry) - alias_model = _strip_optional_string(normalized.pop('name', None)) - model = _strip_optional_string(normalized.get('model')) - if model and alias_model and model != alias_model: - raise ValueError(f'Model role `{name}` cannot define both `model` and `name` with different values.') - model = model or alias_model - source = _strip_optional_string(normalized.get('source')) - type_name = _strip_optional_string(normalized.get('type')) or expected_type - api_key = _strip_optional_string(normalized.get('api_key')) - url = _strip_optional_string(normalized.get('url')) - - if not source: - raise ValueError(f'Model role `{name}` missing required field `source`.') - if not model: - raise ValueError(f'Model role `{name}` missing required field `model`.') - if type_name != expected_type: - raise ValueError( - f'Model role `{name}` has type `{type_name}`, expected `{expected_type}`.' - ) - if not api_key and not normalized.get('skip_auth'): - raise ValueError( - f'Model role `{name}` missing required field `api_key`. ' - 'Use `${ENV_NAME}` in config or set `skip_auth: true` for unauthenticated endpoints.' - ) - - normalized['model'] = model - normalized['source'] = source - normalized['type'] = expected_type - if url: - normalized['url'] = url - elif 'url' in normalized: - normalized.pop('url') - if api_key: - normalized['api_key'] = api_key - elif 'api_key' in normalized: - normalized.pop('api_key') - return normalized - - -def _normalize_index_kwargs(embed_key: str, index_kwargs: Any) -> Dict[str, Any]: - if index_kwargs is None: - normalized = deepcopy(DEFAULT_EMBED_INDEX_KWARGS) - elif isinstance(index_kwargs, dict): - normalized = deepcopy(index_kwargs) - else: - raise ValueError(f'Embedding `{embed_key}` field `index_kwargs` must be a mapping.') - - normalized['embed_key'] = embed_key - return normalized - - -def _normalize_embed_configs( - roles: Dict[str, Any] -) -> tuple[Dict[str, Dict[str, Any]], List[str], List[Dict[str, Any]]]: - embeddings = roles.get('embeddings') - if embeddings is None: - embeddings = {key: roles.get(key) for key in DEFAULT_EMBED_KEYS if roles.get(key) is not None} - if not isinstance(embeddings, dict): - raise ValueError('Runtime model config `embeddings` must be a mapping.') - - unsupported_keys = set(embeddings) - set(DEFAULT_EMBED_KEYS) - if unsupported_keys: - raise ValueError( - f'Unsupported embedding slots: {sorted(unsupported_keys)!r}. ' - f'Only {list(DEFAULT_EMBED_KEYS)!r} are allowed.' - ) - - normalized_embeddings: Dict[str, Dict[str, Any]] = {} - embed_keys: List[str] = [] - index_kwargs: List[Dict[str, Any]] = [] - for embed_key in DEFAULT_EMBED_KEYS: - entry = embeddings.get(embed_key) - if not entry: - continue - normalized = _normalize_model_entry(embed_key, entry, 'embed') - index_kwargs.append(_normalize_index_kwargs(embed_key, normalized.pop('index_kwargs', None))) - normalized_embeddings[embed_key] = normalized - embed_keys.append(embed_key) - - if not embed_keys: - raise ValueError( - 'Runtime model config must enable at least one embedding slot among ' - f'{list(DEFAULT_EMBED_KEYS)!r}.' - ) - return normalized_embeddings, embed_keys, index_kwargs - - -def _resolve_embed_key(name: str, embed_key: str, allowed_keys: List[str]) -> str: - if embed_key not in allowed_keys: - raise ValueError( - f'Config field `{name}` references unknown embed key `{embed_key}`. ' - f'Enabled keys: {allowed_keys!r}.' - ) - return embed_key - - -def _default_file_search_embed_key(embed_keys: List[str], index_kwargs: List[Dict[str, Any]]) -> str: - for item in index_kwargs: - if 'SPARSE' in str(item.get('index_type', '')).upper(): - return item['embed_key'] - return embed_keys[0] - - -def _build_default_retriever_configs(embed_keys: List[str], topk: int = 20) -> List[Dict[str, Any]]: - configs: List[Dict[str, Any]] = [] - for embed_key in embed_keys: - configs.append({ - 'group_name': 'line', - 'embed_keys': [embed_key], - 'topk': topk, - 'target': 'block', - }) - for embed_key in embed_keys: - configs.append({ - 'group_name': 'block', - 'embed_keys': [embed_key], - 'topk': topk, - }) - return configs - - -def _normalize_retriever_configs(retrieval: Dict[str, Any], embed_keys: List[str]) -> List[Dict[str, Any]]: - retriever_configs = retrieval.get('retriever_configs') - if retriever_configs is None: - topk = int(retrieval.get('default_topk', 20)) - return _build_default_retriever_configs(embed_keys=embed_keys, topk=topk) - if not isinstance(retriever_configs, list): - raise ValueError('Config field `retrieval.retriever_configs` must be a list.') - - normalized_configs: List[Dict[str, Any]] = [] - for index, config in enumerate(retriever_configs, start=1): - if not isinstance(config, dict): - raise ValueError(f'Retriever config #{index} must be a mapping.') - embed_keys_for_config = config.get('embed_keys') - if not isinstance(embed_keys_for_config, list) or not embed_keys_for_config: - raise ValueError(f'Retriever config #{index} must define a non-empty `embed_keys` list.') - for embed_key in embed_keys_for_config: - _resolve_embed_key(f'retrieval.retriever_configs[{index}].embed_keys', embed_key, embed_keys) - normalized_configs.append(deepcopy(config)) - return normalized_configs - - -@functools.lru_cache(maxsize=8) -def get_runtime_model_settings(config_path: str | None = None) -> RuntimeModelSettings: - config = load_runtime_model_config(config_path) - roles = _get_runtime_roles(config) - embeddings, embed_keys, index_kwargs = _normalize_embed_configs(roles) - - llm_config = _normalize_model_entry('llm', roles.get('llm'), 'llm') - llm_instruct_raw = roles.get('llm_instruct') or roles.get('llm') - llm_instruct_config = _normalize_model_entry('llm_instruct', llm_instruct_raw, 'llm') - reranker_config = _normalize_model_entry('reranker', roles.get('reranker'), 'rerank') - - retrieval = config.get('retrieval', roles.get('retrieval', {})) or {} - if not isinstance(retrieval, dict): - raise ValueError('Config field `retrieval` must be a mapping.') - - temp_doc_embed_key = _resolve_embed_key( - 'retrieval.temp_doc_embed_key', - retrieval.get('temp_doc_embed_key', embed_keys[0]), - embed_keys, - ) - file_search_embed_key = _resolve_embed_key( - 'retrieval.file_search_embed_key', - retrieval.get('file_search_embed_key', _default_file_search_embed_key(embed_keys, index_kwargs)), - embed_keys, - ) - retriever_configs = _normalize_retriever_configs(retrieval, embed_keys) - - return RuntimeModelSettings( - llm=llm_config, - llm_instruct=llm_instruct_config, - reranker=reranker_config, - embeddings=embeddings, - embed_keys=embed_keys, - index_kwargs=index_kwargs, - retriever_configs=retriever_configs, - temp_doc_embed_key=temp_doc_embed_key, - file_search_embed_key=file_search_embed_key, - ) - - -def build_model(model_config: Any): - if not isinstance(model_config, dict): - raise TypeError('Runtime model config must be a mapping.') - config = deepcopy(model_config) - model_name = config.pop('model') - return _build_inline_auto_model(model_name, config) - - -def build_embedding_models(settings: RuntimeModelSettings | None = None) -> Dict[str, Any]: - active_settings = settings or get_runtime_model_settings() - return { - embed_key: build_model(model_config) - for embed_key, model_config in active_settings.embeddings.items() - } - - -def get_model(model, cfg=None): - if isinstance(model, dict): - config = deepcopy(model) - model_name = config.pop('model', config.pop('name', None)) - if not model_name: - raise ValueError('Inline model config must define `model`.') - if cfg in (None, False): - return _build_inline_auto_model(model_name, config) - return AutoModel(model=model_name, config=cfg, **config) - return AutoModel(model=model, config=cfg) diff --git a/algorithm/configs/auto_model.yaml b/algorithm/configs/auto_model.yaml deleted file mode 100644 index 8bbb555..0000000 --- a/algorithm/configs/auto_model.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# 本地部署模型 -qwen3_32b_custom: - - source: openai - type: llm - url: http://10.119.29.126:25121/v1/ - name: lazyllm - skip_auth: true - -qwen3_moe_custom: - - source: openai - type: llm - url: http://10.119.21.177:8000/v1/ - name: lazyllm - skip_auth: true - -bgem3_emb_dense_custom: - - source: bgem3embed - type: embed - url: http://10.119.27.151:2269/embed - skip_auth: true - -bgem3_emb_sparse_custom: - - source: bgem3embed - type: embed - url: http://10.119.27.151:2269/sparse_embed - skip_auth: true - -bge_reasoner_openai_embed: - - source: openai - type: embed - url: http://10.119.28.105:2280/v1/ - name: BGE-reasoner - skip_auth: true - -bge_m3_openai_embed: - - source: openai - type: embed - url: http://10.119.29.245:9800/v1/ - name: lazyllm - skip_auth: true - -qwen3_reranker_custom: - - source: qwen3rerank - type: rerank - url: http://10.119.24.104:8331/v1/rerank - name: Qwen3-Reranker-8B - skip_auth: true - -# Relay embedding via LazyLLM EmbeddingDeploy.这里不需要真实本地模型权重。 -bge_m3_dense_relay: - - framework: EmbeddingDeploy - type: embed - url: http://10.119.16.203:8006/generate - trust_remote_code: false - -bge_m3_sparse_relay: - - framework: EmbeddingDeploy - type: embed - url: http://10.119.16.203:8003/generate - trust_remote_code: false - - - -# 线上API模型 -text-embedding-v4: # 模型名 - - source: qwen - api_key: xxx diff --git a/algorithm/parsing/build_document.py b/algorithm/parsing/build_document.py index ba62957..4038a7a 100644 --- a/algorithm/parsing/build_document.py +++ b/algorithm/parsing/build_document.py @@ -6,7 +6,8 @@ from lazyllm.tools.rag.parsing_service import DocumentProcessor from lazyllm.tools.rag.readers import PaddleOCRPDFReader -from common.model import build_embedding_models, get_runtime_model_settings +from chat.pipelines.builders.get_models import get_automodel +from chat.utils.load_config import get_retrieval_settings from parsing.transform import NodeParser, GeneralParser, LineSplitter ALGO_ID = 'general_algo' @@ -94,8 +95,8 @@ def _build_pdf_reader(): def build_document() -> Document: processor_url = os.getenv('LAZYRAG_DOCUMENT_PROCESSOR_URL', 'http://localhost:8000') server_port = get_algo_server_port() - settings = get_runtime_model_settings() - embed = build_embedding_models(settings) + settings = get_retrieval_settings() + embed = {k: get_automodel(k) for k in settings.embed_keys} docs = Document( dataset_path=None, diff --git a/algorithm/common/db.py b/algorithm/processor/db.py similarity index 100% rename from algorithm/common/db.py rename to algorithm/processor/db.py diff --git a/algorithm/common/env.py b/algorithm/processor/env.py similarity index 100% rename from algorithm/common/env.py rename to algorithm/processor/env.py diff --git a/algorithm/processor/server.py b/algorithm/processor/server.py index 09210fc..4938302 100644 --- a/algorithm/processor/server.py +++ b/algorithm/processor/server.py @@ -2,8 +2,8 @@ import threading from lazyllm.tools.rag.parsing_service import DocumentProcessor -from common.db import require_shared_db_config -from common.env import env_int +from processor.db import require_shared_db_config +from processor.env import env_int db_config = require_shared_db_config('DocumentProcessor') diff --git a/algorithm/processor/worker.py b/algorithm/processor/worker.py index 1082f13..5c6822e 100644 --- a/algorithm/processor/worker.py +++ b/algorithm/processor/worker.py @@ -3,8 +3,8 @@ import threading from lazyllm.tools.rag.parsing_service import DocumentProcessorWorker -from common.db import require_shared_db_config -from common.env import env_bool, env_float, env_int, env_list +from processor.db import require_shared_db_config +from processor.env import env_bool, env_float, env_int, env_list db_config = require_shared_db_config('DocumentProcessorWorker') diff --git a/backend/core/doc/doc_server.py b/backend/core/doc/doc_server.py index f88d072..dc57d8d 100644 --- a/backend/core/doc/doc_server.py +++ b/backend/core/doc/doc_server.py @@ -23,8 +23,8 @@ if path not in sys.path: sys.path.insert(0, path) -from algorithm.common.db import require_shared_db_config # noqa: E402 -from algorithm.common.env import env_int, env_float # noqa: E402 +from algorithm.processor.db import require_shared_db_config # noqa: E402 +from algorithm.processor.env import env_int, env_float # noqa: E402 from lazyllm.tools.rag.doc_service import DocServer # noqa: E402