From 41eb45d3760db3090ce0ac2ac544e8c78be5b4bd Mon Sep 17 00:00:00 2001
From: HeadBro1 <84711752+HeadBro1@users.noreply.github.com>
Date: Thu, 11 Jun 2026 18:11:36 +0800
Subject: [PATCH] Improve SSE error handling
---
chatbot.py | 146 +++++++++++++++++++++++++++++++++++++----------------
1 file changed, 103 insertions(+), 43 deletions(-)
diff --git a/chatbot.py b/chatbot.py
index f6d797b..6391521 100644
--- a/chatbot.py
+++ b/chatbot.py
@@ -1,8 +1,13 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import json
+import logging
+import os
+import re
+from collections import OrderedDict
+
from agentscope.agent import ReActAgent
from agentscope.model import OpenAIChatModel
from agentscope.formatter import OpenAIChatFormatter
@@ -10,24 +15,67 @@
from agentscope.tool import Toolkit, execute_python_code
from agentscope.session import JSONSession
from agentscope.message import Msg
-import os
-from collections import OrderedDict
+
+logger = logging.getLogger(__name__)
app = FastAPI(title="Simple SSE API")
app.add_middleware(
CORSMiddleware,
- allow_origins=["*"],
+ allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)
+SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")
+
+
+def build_sse(data, event=None):
+ payload = json.dumps(data, ensure_ascii=False)
+ if event:
+ return f"event: {event}\ndata: {payload}\n\n"
+ return f"data: {payload}\n\n"
+
+
+def validate_session_id(username):
+ if not SESSION_ID_RE.fullmatch(username):
+ raise HTTPException(
+ status_code=400,
+ detail="username must be 1-64 characters and contain only letters, numbers, '_' or '-'",
+ )
+
+
+@app.get("/health")
+async def health():
+ return {"status": "ok"}
+
+
@app.get("/stream")
-async def stream(username,query):
+async def stream(username: str, query: str):
+ username = username.strip()
+ query = query.strip()
+ validate_session_id(username)
+ if not query:
+ raise HTTPException(status_code=400, detail="query must not be empty")
+
+ async def error_stream(message):
+ yield build_sse({"message": f"Error: {message}"})
+
+ api_key = os.environ.get("API_KEY")
+ if not api_key:
+ return StreamingResponse(
+ error_stream("API_KEY environment variable is not configured"),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ },
+ )
+
# 会话管理
- session_mgr=JSONSession(save_dir='./sessions')
-
+ session_mgr = JSONSession(save_dir="./sessions")
+
# 智能体
toolkit = Toolkit()
toolkit.register_tool_function(execute_python_code)
@@ -36,10 +84,10 @@ async def stream(username,query):
sys_prompt="幽默的助手,爱开玩笑",
model=OpenAIChatModel(
client_args={
- 'base_url':'https://dashscope.aliyuncs.com/compatible-mode/v1',
+ "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
},
model_name="qwen-max",
- api_key=os.environ.get('API_KEY'),
+ api_key=api_key,
stream=True,
),
memory=InMemoryMemory(),
@@ -47,61 +95,73 @@ async def stream(username,query):
toolkit=toolkit,
)
agent.set_console_output_enabled(False)
-
+
# 清理早期记忆
- async def mem_chunk(agent,kwargs,output):
- while await agent.memory.size()>50:
+ async def mem_chunk(agent, kwargs, output):
+ while await agent.memory.size() > 50:
await agent.memory.delete(0)
- agent.register_instance_hook(hook_type="post_reply",hook_name="mem_trunc", hook=mem_chunk)
-
+
+ agent.register_instance_hook(hook_type="post_reply", hook_name="mem_trunc", hook=mem_chunk)
+
# 加载状态
- await session_mgr.load_session_state(session_id=username,allow_not_exist=True,agent=agent)
-
- sse_queue=asyncio.Queue()
-
- msg_dict=OrderedDict()
-
+ await session_mgr.load_session_state(session_id=username, allow_not_exist=True, agent=agent)
+
+ sse_queue = asyncio.Queue()
+
+ msg_dict = OrderedDict()
+
# 应答放入队列
- async def stream_collect(agent,kwargs):
- response=''
- for block in kwargs['msg'].get_content_blocks('text'):
- response=block['text']
+ async def stream_collect(agent, kwargs):
+ response = ""
+ for block in kwargs["msg"].get_content_blocks("text"):
+ response = block["text"]
if response:
- msg_dict[kwargs['msg'].id]=response
- await sse_queue.put('
'.join([v for k,v in msg_dict.items()]))
- agent.register_instance_hook(hook_type="pre_print",hook_name="sse",hook=stream_collect)
+ msg_dict[kwargs["msg"].id] = response
+ await sse_queue.put({
+ "event": None,
+ "data": {"message": "
".join([v for k, v in msg_dict.items()])},
+ })
+
+ agent.register_instance_hook(hook_type="pre_print", hook_name="sse", hook=stream_collect)
# 应答回复客户端
async def stream_response():
while True:
- msg=await sse_queue.get()
- if msg is None:
+ item = await sse_queue.get()
+ if item is None:
break
- chunk=json.dumps({"message": msg})
- yield 'data: %s\n\n' % chunk
-
+ yield build_sse(item["data"], event=item["event"])
+
# 异步执行agent
async def execute_agent():
- final_msg=None
+ final_msg = None
try:
- input_msg=Msg(name=username,content=query,role='user')
- final_msg=await agent(input_msg)
- except Exception as e:
- print(e)
- await session_mgr.save_session_state(session_id=username,agent=agent)
- await sse_queue.put(None)
- print(f'Username:{username} Query:{query} Answer:{final_msg}')
+ input_msg = Msg(name=username, content=query, role="user")
+ final_msg = await agent(input_msg)
+ await session_mgr.save_session_state(session_id=username, agent=agent)
+ except Exception as exc:
+ logger.exception("Agent execution failed")
+ await sse_queue.put({
+ "event": None,
+ "data": {"message": f"Error: {str(exc) or 'Agent execution failed'}"},
+ })
+ finally:
+ await sse_queue.put(None)
+ logger.info("Username:%s Query:%s Answer:%s", username, query, final_msg)
+
asyncio.create_task(execute_agent())
-
+
return StreamingResponse(
stream_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
- }
+ },
)
+
if __name__ == "__main__":
import uvicorn
- uvicorn.run(app, host="0.0.0.0", port=8000)
\ No newline at end of file
+
+ uvicorn.run(app, host="0.0.0.0", port=8000)