From 41eb45d3760db3090ce0ac2ac544e8c78be5b4bd Mon Sep 17 00:00:00 2001 From: HeadBro1 <84711752+HeadBro1@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:11:36 +0800 Subject: [PATCH] Improve SSE error handling --- chatbot.py | 146 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 103 insertions(+), 43 deletions(-) diff --git a/chatbot.py b/chatbot.py index f6d797b..6391521 100644 --- a/chatbot.py +++ b/chatbot.py @@ -1,8 +1,13 @@ -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from fastapi.responses import StreamingResponse from fastapi.middleware.cors import CORSMiddleware import asyncio import json +import logging +import os +import re +from collections import OrderedDict + from agentscope.agent import ReActAgent from agentscope.model import OpenAIChatModel from agentscope.formatter import OpenAIChatFormatter @@ -10,24 +15,67 @@ from agentscope.tool import Toolkit, execute_python_code from agentscope.session import JSONSession from agentscope.message import Msg -import os -from collections import OrderedDict + +logger = logging.getLogger(__name__) app = FastAPI(title="Simple SSE API") app.add_middleware( CORSMiddleware, - allow_origins=["*"], + allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], expose_headers=["*"], ) +SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$") + + +def build_sse(data, event=None): + payload = json.dumps(data, ensure_ascii=False) + if event: + return f"event: {event}\ndata: {payload}\n\n" + return f"data: {payload}\n\n" + + +def validate_session_id(username): + if not SESSION_ID_RE.fullmatch(username): + raise HTTPException( + status_code=400, + detail="username must be 1-64 characters and contain only letters, numbers, '_' or '-'", + ) + + +@app.get("/health") +async def health(): + return {"status": "ok"} + + @app.get("/stream") -async def stream(username,query): +async def stream(username: str, query: str): + username = username.strip() + query = query.strip() + validate_session_id(username) + if not query: + raise HTTPException(status_code=400, detail="query must not be empty") + + async def error_stream(message): + yield build_sse({"message": f"Error: {message}"}) + + api_key = os.environ.get("API_KEY") + if not api_key: + return StreamingResponse( + error_stream("API_KEY environment variable is not configured"), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + }, + ) + # 会话管理 - session_mgr=JSONSession(save_dir='./sessions') - + session_mgr = JSONSession(save_dir="./sessions") + # 智能体 toolkit = Toolkit() toolkit.register_tool_function(execute_python_code) @@ -36,10 +84,10 @@ async def stream(username,query): sys_prompt="幽默的助手,爱开玩笑", model=OpenAIChatModel( client_args={ - 'base_url':'https://dashscope.aliyuncs.com/compatible-mode/v1', + "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1", }, model_name="qwen-max", - api_key=os.environ.get('API_KEY'), + api_key=api_key, stream=True, ), memory=InMemoryMemory(), @@ -47,61 +95,73 @@ async def stream(username,query): toolkit=toolkit, ) agent.set_console_output_enabled(False) - + # 清理早期记忆 - async def mem_chunk(agent,kwargs,output): - while await agent.memory.size()>50: + async def mem_chunk(agent, kwargs, output): + while await agent.memory.size() > 50: await agent.memory.delete(0) - agent.register_instance_hook(hook_type="post_reply",hook_name="mem_trunc", hook=mem_chunk) - + + agent.register_instance_hook(hook_type="post_reply", hook_name="mem_trunc", hook=mem_chunk) + # 加载状态 - await session_mgr.load_session_state(session_id=username,allow_not_exist=True,agent=agent) - - sse_queue=asyncio.Queue() - - msg_dict=OrderedDict() - + await session_mgr.load_session_state(session_id=username, allow_not_exist=True, agent=agent) + + sse_queue = asyncio.Queue() + + msg_dict = OrderedDict() + # 应答放入队列 - async def stream_collect(agent,kwargs): - response='' - for block in kwargs['msg'].get_content_blocks('text'): - response=block['text'] + async def stream_collect(agent, kwargs): + response = "" + for block in kwargs["msg"].get_content_blocks("text"): + response = block["text"] if response: - msg_dict[kwargs['msg'].id]=response - await sse_queue.put('
'.join([v for k,v in msg_dict.items()])) - agent.register_instance_hook(hook_type="pre_print",hook_name="sse",hook=stream_collect) + msg_dict[kwargs["msg"].id] = response + await sse_queue.put({ + "event": None, + "data": {"message": "
".join([v for k, v in msg_dict.items()])}, + }) + + agent.register_instance_hook(hook_type="pre_print", hook_name="sse", hook=stream_collect) # 应答回复客户端 async def stream_response(): while True: - msg=await sse_queue.get() - if msg is None: + item = await sse_queue.get() + if item is None: break - chunk=json.dumps({"message": msg}) - yield 'data: %s\n\n' % chunk - + yield build_sse(item["data"], event=item["event"]) + # 异步执行agent async def execute_agent(): - final_msg=None + final_msg = None try: - input_msg=Msg(name=username,content=query,role='user') - final_msg=await agent(input_msg) - except Exception as e: - print(e) - await session_mgr.save_session_state(session_id=username,agent=agent) - await sse_queue.put(None) - print(f'Username:{username} Query:{query} Answer:{final_msg}') + input_msg = Msg(name=username, content=query, role="user") + final_msg = await agent(input_msg) + await session_mgr.save_session_state(session_id=username, agent=agent) + except Exception as exc: + logger.exception("Agent execution failed") + await sse_queue.put({ + "event": None, + "data": {"message": f"Error: {str(exc) or 'Agent execution failed'}"}, + }) + finally: + await sse_queue.put(None) + logger.info("Username:%s Query:%s Answer:%s", username, query, final_msg) + asyncio.create_task(execute_agent()) - + return StreamingResponse( stream_response(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", - } + }, ) + if __name__ == "__main__": import uvicorn - uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file + + uvicorn.run(app, host="0.0.0.0", port=8000)