[Refactor] refract chat: refactored the chat portion of the current algorithm #22
Status: Open

chenhao0205 wants to merge 15 commits into LazyAGI:main from chenhao0205:ch/refractor
Changes shown are from 5 of the 15 commits.

Commits (15, all by chenhao0205):

- d656043 refractor
- 22a8ec8 refractor
- f0f76a9 minor modify
- 81c8d32 minor modify
- c3dbec3 minor modify
- 154925e minor modify
- 518c1e1 minor modify
- bd00593 Apply suggestion from @gemini-code-assist[bot]
- 080929f add
- e0f92fd minor modify
- c50c2bb delete tools
- ed04fa2 lint check
- 4eb6544 lint check
- ad0787f lint check
- 4c56768 solve conflict
Empty file.
Binary file not shown.
New file (+6 lines), the `chat.app.api` package init (path inferred from the `from chat.app.api import create_app` import elsewhere in this PR):

```python
# Route definitions (v1/v2)
from __future__ import annotations

from chat.app.core.chat_server import create_app

__all__ = ['create_app']
```
New file (+43 lines), the chat routes (module `chat.app.api.chat_routes`, path inferred from imports):

```python
from typing import Any, Dict, List, Optional

from fastapi import APIRouter, Body, Request
from chat.app.core.chat_service import handle_chat

router = APIRouter()


@router.post('/api/chat', summary='Chat with the knowledge base')
@router.post('/api/chat/stream', summary='Chat with the knowledge base')
async def chat(
    query: str = Body(..., description='User question'),
    history: Optional[List[Dict[str, Any]]] = Body(
        default=None, description='Chat history (each item may contain role and content)'
    ),
    session_id: str = Body('session_id', description='Session ID'),
    filters: Optional[Dict[str, Any]] = Body(None, description='Retrieval filter conditions'),
    files: Optional[List[str]] = Body(None, description='Uploaded temporary files'),
    debug: Optional[bool] = Body(False, description='Whether to enable debug mode'),
    reasoning: Optional[bool] = Body(False, description='Whether to enable reasoning'),
    databases: Optional[List[Dict]] = Body([], description='Associated databases'),
    dataset: Optional[str] = Body('debug', description='Database name'),
    priority: Optional[int] = Body(
        None,
        description='Request priority for vllm scheduling; higher values mean higher priority',
    ),
    *,
    request: Request,
):
    is_stream = request.url.path.endswith('/stream')
    return await handle_chat(
        query=query,
        history=history,
        session_id=session_id,
        filters=filters,
        files=files,
        debug=debug,
        reasoning=reasoning,
        databases=databases,
        dataset=dataset,
        priority=priority,
        is_stream=is_stream,
    )
```
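A minimal client-side sketch (the payload values are hypothetical; the field names mirror the `Body(...)` parameters above) showing the request body both endpoints accept, and the path suffix that switches the shared handler into streaming mode:

```python
import json

# Hypothetical example payload; every key mirrors a Body(...) parameter
# of the chat route above.
payload = {
    "query": "What is in the knowledge base?",
    "history": [{"role": "user", "content": "hello"}],
    "session_id": "demo-session",
    "filters": None,
    "debug": False,
    "reasoning": False,
    "dataset": "debug",
}

def is_stream_path(path: str) -> bool:
    # Mirrors the handler's dispatch: one function serves both routes and
    # streams only when the URL path ends with '/stream'.
    return path.endswith('/stream')

body = json.dumps(payload)  # what a client would POST as application/json
```

Registering one handler under two paths and branching on `request.url.path` keeps the streaming and non-streaming contracts identical except for the response framing.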
New file (+21 lines), the health routes (module `chat.app.api.health_routes`, path inferred from imports):

```python
import os
import urllib.request

from fastapi import APIRouter

router = APIRouter()


@router.get('/health', summary='Health check')
@router.get('/api/health', summary='Health check (API path)')
async def health():
    doc_url = os.getenv('LAZYRAG_DOCUMENT_SERVER_URL', 'http://localhost:8000')
    status = {'document_server_url': doc_url, 'document_server_reachable': None}
    try:
        req = urllib.request.Request(doc_url.rstrip('/') + '/', method='GET')
        urllib.request.urlopen(req, timeout=3)
        status['document_server_reachable'] = True
    except Exception as e:
        status['document_server_reachable'] = False
        status['document_server_error'] = str(e)
    return status
```
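The reachability probe inside the handler above can be factored into a standalone helper so it is testable without running the server. A sketch (the function name is mine, not part of the PR):

```python
import urllib.request

def probe_document_server(doc_url: str, timeout: float = 3.0) -> dict:
    # Performs the same check as the /health route: issue a GET against the
    # document server root and record whether it answered.
    status = {'document_server_url': doc_url, 'document_server_reachable': None}
    try:
        req = urllib.request.Request(doc_url.rstrip('/') + '/', method='GET')
        urllib.request.urlopen(req, timeout=timeout)
        status['document_server_reachable'] = True
    except Exception as e:
        status['document_server_reachable'] = False
        status['document_server_error'] = str(e)
    return status

# Port 9 (discard) is almost never listening, so the probe reports failure:
result = probe_document_server('http://127.0.0.1:9', timeout=1.0)
```

Catching bare `Exception` here is deliberate: any failure mode (DNS, refused connection, timeout) should degrade to `reachable: False` rather than break the health endpoint itself.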
New file (+16 lines), the server entry point:

```python
from __future__ import annotations

from chat.app.api import create_app

app = create_app()

if __name__ == '__main__':
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default='0.0.0.0', help='listen host')
    parser.add_argument('--port', type=int, default=8046, help='listen port')
    args = parser.parse_args()

    uvicorn.run(app, host=args.host, port=args.port)
```
New file (+57 lines), the app factory and server singleton (module `chat.app.core.chat_server`, path inferred from imports):

```python
from __future__ import annotations
from fastapi import FastAPI
from lazyllm import LOG, once_wrapper

from chat.config import URL_MAP, SENSITIVE_WORDS_PATH
from chat.pipelines.agentic import agentic_rag
from chat.pipelines.naive import get_ppl_naive
from chat.components.process.sensitive_filter import SensitiveFilter


def create_app() -> FastAPI:
    """Initialize the FastAPI app and mount routes; pipelines are registered
    by ChatServer at module import time."""
    app = FastAPI(
        title='LazyLLM Chat API',
        description='Knowledge-base-backed chat API service',
        version='1.0.0',
    )
    from chat.app.api import chat_routes, health_routes

    app.include_router(health_routes.router)
    app.include_router(chat_routes.router)
    return app


class ChatServer:
    def __init__(self):
        self._on_server_start()

    @once_wrapper
    def _on_server_start(self):
        try:
            self.query_ppl = {
                name: get_ppl_naive(url=doc_url)
                for name, doc_url in URL_MAP.items()
            }
            self.query_ppl_stream = {
                name: get_ppl_naive(url=doc_url, stream=True)
                for name, doc_url in URL_MAP.items()
            }
            self.query_ppl_reasoning = agentic_rag
            self.sensitive_filter = SensitiveFilter(SENSITIVE_WORDS_PATH)

            if self.sensitive_filter.loaded:
                LOG.info(
                    f'[ChatServer] [SENSITIVE_FILTER] Successfully loaded '
                    f'{self.sensitive_filter.keyword_count} sensitive keywords'
                )
            else:
                LOG.warning('[ChatServer] [SENSITIVE_FILTER] Failed to load, filter disabled')

            LOG.info('[ChatServer] [SERVER_START]')
        except Exception as exc:
            LOG.exception('[ChatServer] [SERVER_START_ERROR]')
            raise exc


chat_server = ChatServer()
```
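`once_wrapper` comes from lazyllm; its role here is to make `_on_server_start` idempotent so that re-instantiating `ChatServer` cannot rebuild the pipelines. A generic stand-in (my sketch under that assumption, not lazyllm's actual implementation) behaves like this:

```python
import functools

def once(fn):
    # Run fn only on the first call per instance; later calls are no-ops.
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        flag = f'_{fn.__name__}_done'
        if getattr(self, flag, False):
            return None
        setattr(self, flag, True)
        return fn(self, *args, **kwargs)
    return wrapper

class Demo:
    def __init__(self):
        self.starts = 0

    @once
    def start(self):
        self.starts += 1

d = Demo()
d.start()
d.start()  # second call is skipped
```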
New file (+213 lines), the chat service (module `chat.app.core.chat_service`, path inferred from imports). In the submitted diff, `build_query_params` and `log_chat_request` declared `databases` before `image_files` while every call site passed `image_files` first, silently swapping the two values; the signatures below follow the call sites:

```python
from __future__ import annotations
import asyncio
import json
import time
import uuid
from typing import Any, Dict, List, Optional, Union
import lazyllm
from lazyllm import LOG
from fastapi.responses import StreamingResponse
from chat.config import (URL_MAP, RAG_MODE, MULTIMODAL_MODE, MAX_CONCURRENCY,
                         LAZYRAG_LLM_PRIORITY, SENSITIVE_FILTER_RESPONSE_TEXT)
from chat.utils.helpers import validate_and_resolve_files
from chat.app.core.chat_server import chat_server


rag_sem = asyncio.Semaphore(MAX_CONCURRENCY)


def _sse_line(payload: Dict[str, Any]) -> str:
    # One JSON object per event, terminated by a blank line.
    return json.dumps(payload, ensure_ascii=False, default=str) + '\n\n'


def _resp(code: int, msg: str, data: Any, cost: float) -> Dict[str, Any]:
    return {'code': code, 'msg': msg, 'data': data, 'cost': cost}


def check_sensitive_content(
    query: str, session_id: str, start_time: float
) -> Optional[Dict[str, Any]]:
    if not chat_server.sensitive_filter.loaded:
        return None
    has_sensitive, sensitive_word = chat_server.sensitive_filter.check(query)
    if has_sensitive:
        cost = round(time.time() - start_time, 3)
        LOG.warning(
            f'[ChatServer] [SENSITIVE_FILTER_BLOCKED] [query={query[:50]}...] '
            f'[sensitive_word={sensitive_word}] [session_id={session_id}]'
        )
        return _resp(
            200,
            'success',
            {
                'think': None,
                'text': SENSITIVE_FILTER_RESPONSE_TEXT,
                'sources': [],
            },
            cost,
        )
    return None


def build_query_params(query: str, history: Optional[List[Dict[str, Any]]],
                       filters: Optional[Dict[str, Any]], other_files: List[str],
                       image_files: List[str], debug: bool,
                       databases: Optional[List[Dict[str, Any]]],
                       priority: Optional[int]) -> Dict[str, Any]:
    # Parameter order matches the call sites (image_files before databases).
    hist = [
        {
            'role': str(h.get('role', 'assistant')),
            'content': str(h.get('content', '')),
        }
        for h in (history or [])
        if isinstance(h, dict)
    ]
    return {
        'query': query,
        'history': hist,
        'filters': filters if RAG_MODE and filters else {},
        'files': other_files,
        'image_files': image_files if MULTIMODAL_MODE and image_files else [],
        'debug': debug,
        'databases': databases if RAG_MODE and databases else [],
        'priority': priority,
    }


def log_chat_request(query: str, session_id: str, filters: Optional[Dict[str, Any]],
                     other_files: List[str], image_files: List[str],
                     databases: Optional[List[Dict[str, Any]]], cost: float,
                     response: Any = None, log_type: str = 'KB_CHAT') -> None:
    # Parameter order matches the call sites (image_files before databases).
    databases_str = json.dumps(databases, ensure_ascii=False) if databases else '[]'
    LOG.info(
        f'[ChatServer] [{log_type}] [query={query}] [session_id={session_id}] '
        f'[filters={filters}] [files={other_files}] [image_files={image_files}] '
        f'[databases={databases_str}] [cost={cost}] [response={response}]'
    )


async def handle_chat(query: str, history: Optional[List[Dict[str, Any]]],
                      session_id: str, filters: Optional[Dict[str, Any]],
                      files: Optional[List[str]], debug: Optional[bool],
                      reasoning: Optional[bool],
                      databases: Optional[List[Dict[str, Any]]], dataset: Optional[str],
                      priority: Optional[int],
                      is_stream: bool) -> Union[Dict[str, Any], StreamingResponse]:
    result = None
    priority = LAZYRAG_LLM_PRIORITY if priority is None else priority

    if dataset not in chat_server.query_ppl:
        return _resp(400, f'dataset {dataset} not found', None, 0.0)

    start_time = time.time()
    sensitive_check_result = check_sensitive_content(query, session_id, start_time)
    sid = f'{session_id}_{time.time()}_{uuid.uuid4().hex}'
    log_tag = 'KB_CHAT_STREAM' if is_stream else 'KB_CHAT'
    LOG.info(f'[ChatServer] [{log_tag}] [query={query}] [sid={sid}]')

    if not is_stream:
        if sensitive_check_result:
            return sensitive_check_result

        other_files, image_files = validate_and_resolve_files(files)
        query_params = build_query_params(
            query, history, filters, other_files, image_files,
            debug or False, databases, priority
        )

        try:
            async with rag_sem:
                lazyllm.globals._init_sid(sid=sid)
                lazyllm.locals._init_sid(sid=sid)
                result = await _run_sync_ppl(
                    bool(reasoning), dataset, query_params, query, filters, priority
                )
                cost = round(time.time() - start_time, 3)
                return _resp(200, 'success', result, cost)
        except Exception as exc:
            LOG.exception(exc)
            cost = round(time.time() - start_time, 3)
            return _resp(500, f'chat service failed: {exc}', None, cost)
        finally:
            cost = round(time.time() - start_time, 3)
            log_chat_request(
                query, sid, filters, other_files, image_files, databases, cost, result
            )
    else:
        if sensitive_check_result:

            async def error_stream():
                yield _sse_line(sensitive_check_result)
                yield _sse_line(_resp(200, 'success', {'status': 'FINISHED'}, 0.0))

            return StreamingResponse(error_stream(), media_type='text/event-stream')

        first_frame_logged = False
        other_files, image_files = validate_and_resolve_files(files)
        collected_chunks: List[str] = []

        query_params = build_query_params(
            query, history, filters, other_files, image_files, False, databases, priority
        )

        stream_call = (
            (chat_server.query_ppl_reasoning, query_params, None, True)
            if reasoning
            else (chat_server.query_ppl_stream[dataset], query_params)
        )

        async def event_stream(ppl, *args) -> Any:
            nonlocal first_frame_logged
            try:
                async with rag_sem:
                    lazyllm.globals._init_sid(sid=sid)
                    lazyllm.locals._init_sid(sid=sid)
                    async_result = await asyncio.to_thread(ppl, *args)
                    async for chunk in async_result:
                        now = time.time()
                        if not first_frame_logged:
                            first_cost = round(now - start_time, 3)
                            LOG.info(
                                f'[ChatServer] [KB_CHAT_STREAM_FIRST_FRAME] '
                                f'[query={query}] [session_id={session_id}] '
                                f'[cost={first_cost}]'
                            )
                            first_frame_logged = True

                        chunk_str = (
                            chunk
                            if isinstance(chunk, str)
                            else json.dumps(chunk, ensure_ascii=False)
                        )
                        collected_chunks.append(chunk_str)
                        cost = round(now - start_time, 3)
                        yield _sse_line(_resp(200, 'success', chunk, cost))

            except Exception as exc:
                LOG.exception(exc)
                collected_chunks.append(f'[EXCEPTION]: {exc}')
                final_resp = _resp(
                    500, f'chat service failed: {exc}', {'status': 'FAILED'}, 0.0
                )
            else:
                final_resp = _resp(200, 'success', {'status': 'FINISHED'}, 0.0)

            cost = round(time.time() - start_time, 3)
            final_resp['cost'] = cost
            yield _sse_line(final_resp)

            log_chat_request(query, sid, filters, other_files, image_files, databases,
                             cost, '\n'.join(collected_chunks), 'KB_CHAT_STREAM_FINISH')

        return StreamingResponse(
            event_stream(*stream_call), media_type='text/event-stream'
        )


async def _run_sync_ppl(reasoning: bool, dataset: str, query_params: Dict[str, Any],
                        query: str, filters: Optional[Dict[str, Any]],
                        priority: Any) -> Any:
    if reasoning:
        return await asyncio.to_thread(
            chat_server.query_ppl_reasoning,
            {'query': query},
            {
                'kb_search': {
                    'filters': filters,
                    'files': [],
                    'stream': False,
                    'priority': priority,
                    'document_url': URL_MAP[dataset],
                }
            },
            False,
        )
    return await asyncio.to_thread(chat_server.query_ppl[dataset], query_params)
```
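The response envelope and SSE framing used throughout `handle_chat` can be exercised standalone; this reproduces the `_resp` and `_sse_line` helpers from the diff above so the wire format is easy to inspect:

```python
import json

def _resp(code, msg, data, cost):
    # Uniform envelope every endpoint returns.
    return {'code': code, 'msg': msg, 'data': data, 'cost': cost}

def _sse_line(payload):
    # One JSON object per event, terminated by a blank line. Note the diff
    # emits bare JSON without the conventional 'data: ' SSE field prefix,
    # so clients must parse each double-newline-delimited frame as JSON.
    return json.dumps(payload, ensure_ascii=False, default=str) + '\n\n'

frame = _sse_line(_resp(200, 'success', {'status': 'FINISHED'}, 0.123))
parsed = json.loads(frame)
```

Because the final `FINISHED`/`FAILED` frame uses the same envelope as content chunks, a client can drive its loop purely off `data['status']`.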
File renamed without changes.