Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 103 additions & 43 deletions chatbot.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,81 @@
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import asyncio
import json
import logging
import os
import re
from collections import OrderedDict

from agentscope.agent import ReActAgent
from agentscope.model import OpenAIChatModel
from agentscope.formatter import OpenAIChatFormatter
from agentscope.memory import InMemoryMemory
from agentscope.tool import Toolkit, execute_python_code
from agentscope.session import JSONSession
from agentscope.message import Msg
import os
from collections import OrderedDict

logger = logging.getLogger(__name__)

app = FastAPI(title="Simple SSE API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)

SESSION_ID_RE = re.compile(r"^[A-Za-z0-9_-]{1,64}$")


def build_sse(data, event=None):
payload = json.dumps(data, ensure_ascii=False)
if event:
return f"event: {event}\ndata: {payload}\n\n"
return f"data: {payload}\n\n"


def validate_session_id(username):
if not SESSION_ID_RE.fullmatch(username):
raise HTTPException(
status_code=400,
detail="username must be 1-64 characters and contain only letters, numbers, '_' or '-'",
)


@app.get("/health")
async def health():
return {"status": "ok"}


@app.get("/stream")
async def stream(username,query):
async def stream(username: str, query: str):
username = username.strip()
query = query.strip()
validate_session_id(username)
if not query:
raise HTTPException(status_code=400, detail="query must not be empty")

async def error_stream(message):
yield build_sse({"message": f"Error: {message}"})

api_key = os.environ.get("API_KEY")
if not api_key:
return StreamingResponse(
error_stream("API_KEY environment variable is not configured"),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)

# 会话管理
session_mgr=JSONSession(save_dir='./sessions')
session_mgr = JSONSession(save_dir="./sessions")

# 智能体
toolkit = Toolkit()
toolkit.register_tool_function(execute_python_code)
Expand All @@ -36,72 +84,84 @@ async def stream(username,query):
sys_prompt="幽默的助手,爱开玩笑",
model=OpenAIChatModel(
client_args={
'base_url':'https://dashscope.aliyuncs.com/compatible-mode/v1',
"base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
},
model_name="qwen-max",
api_key=os.environ.get('API_KEY'),
api_key=api_key,
stream=True,
),
memory=InMemoryMemory(),
formatter=OpenAIChatFormatter(),
toolkit=toolkit,
)
agent.set_console_output_enabled(False)

# 清理早期记忆
async def mem_chunk(agent,kwargs,output):
while await agent.memory.size()>50:
async def mem_chunk(agent, kwargs, output):
while await agent.memory.size() > 50:
await agent.memory.delete(0)
agent.register_instance_hook(hook_type="post_reply",hook_name="mem_trunc", hook=mem_chunk)


agent.register_instance_hook(hook_type="post_reply", hook_name="mem_trunc", hook=mem_chunk)

# 加载状态
await session_mgr.load_session_state(session_id=username,allow_not_exist=True,agent=agent)
sse_queue=asyncio.Queue()
msg_dict=OrderedDict()
await session_mgr.load_session_state(session_id=username, allow_not_exist=True, agent=agent)

sse_queue = asyncio.Queue()

msg_dict = OrderedDict()

# 应答放入队列
async def stream_collect(agent,kwargs):
response=''
for block in kwargs['msg'].get_content_blocks('text'):
response=block['text']
async def stream_collect(agent, kwargs):
response = ""
for block in kwargs["msg"].get_content_blocks("text"):
response = block["text"]
if response:
msg_dict[kwargs['msg'].id]=response
await sse_queue.put('<br>'.join([v for k,v in msg_dict.items()]))
agent.register_instance_hook(hook_type="pre_print",hook_name="sse",hook=stream_collect)
msg_dict[kwargs["msg"].id] = response
await sse_queue.put({
"event": None,
"data": {"message": "<br>".join([v for k, v in msg_dict.items()])},
})

agent.register_instance_hook(hook_type="pre_print", hook_name="sse", hook=stream_collect)

# 应答回复客户端
async def stream_response():
while True:
msg=await sse_queue.get()
if msg is None:
item = await sse_queue.get()
if item is None:
break
chunk=json.dumps({"message": msg})
yield 'data: %s\n\n' % chunk

yield build_sse(item["data"], event=item["event"])

# 异步执行agent
async def execute_agent():
final_msg=None
final_msg = None
try:
input_msg=Msg(name=username,content=query,role='user')
final_msg=await agent(input_msg)
except Exception as e:
print(e)
await session_mgr.save_session_state(session_id=username,agent=agent)
await sse_queue.put(None)
print(f'Username:{username} Query:{query} Answer:{final_msg}')
input_msg = Msg(name=username, content=query, role="user")
final_msg = await agent(input_msg)
await session_mgr.save_session_state(session_id=username, agent=agent)
except Exception as exc:
logger.exception("Agent execution failed")
await sse_queue.put({
"event": None,
"data": {"message": f"Error: {str(exc) or 'Agent execution failed'}"},
})
finally:
await sse_queue.put(None)
logger.info("Username:%s Query:%s Answer:%s", username, query, final_msg)

asyncio.create_task(execute_agent())

return StreamingResponse(
stream_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
}
},
)


if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

uvicorn.run(app, host="0.0.0.0", port=8000)