diff --git a/frontend/src/components/ChatWindow.tsx b/frontend/src/components/ChatWindow.tsx index df25c692..9023efb0 100644 --- a/frontend/src/components/ChatWindow.tsx +++ b/frontend/src/components/ChatWindow.tsx @@ -212,6 +212,42 @@ function groupedBlocksToAssistantContent(blocks: ContentBlock[]): string { return content; } +// ── Stream seed helpers (sessionStorage) ───────────────────────────── +// Persists streaming parts across page refreshes so the frontend can +// (a) reconnect with a visual seed, or (b) detect a missed stream completion. +const STREAM_SEED_KEY = 'chat_stream_seed_v1'; + +interface StreamSeed { + chatId: string; + parts: AGUIPart[]; + ts: number; +} + +function _saveStreamSeed(chatId: string, parts: AGUIPart[]): void { + try { + const seed: StreamSeed = { chatId, parts, ts: Date.now() }; + sessionStorage.setItem(STREAM_SEED_KEY, JSON.stringify(seed)); + } catch { /* storage may be full or unavailable */ } +} + +function _loadStreamSeed(chatId: string): StreamSeed | null { + try { + const raw = sessionStorage.getItem(STREAM_SEED_KEY); + if (!raw) return null; + const seed: StreamSeed = JSON.parse(raw); + if (seed.chatId !== chatId) return null; + if (Date.now() - seed.ts > 10 * 60 * 1000) { + sessionStorage.removeItem(STREAM_SEED_KEY); + return null; + } + return seed; + } catch { return null; } +} + +function _clearStreamSeed(): void { + try { sessionStorage.removeItem(STREAM_SEED_KEY); } catch { /* ignore */ } +} + // Drag overlay component const DragOverlay: React.FC = () => { const { t } = useI18n(); @@ -460,7 +496,6 @@ export const ChatWindow: React.FC = ({ setConfig, shouldResetNext, consumeResetFlag, - forceSaveNow, updateMessage, truncateMessagesFrom, setIsStreaming, @@ -572,12 +607,11 @@ export const ChatWindow: React.FC = ({ const { parts: streamingParts, sendMessage: sendAGUI, - resumeStream, - steerStream, stop: stopAGUIStream, stopSilently: stopAGUIStreamSilently, getParts: getStreamingParts, clearParts, + restorePartsFromSeed, resolveApproval, addApprovalDecision, consumeApprovalDecisions, @@ -599,6 +633,7 @@ export const ChatWindow: React.FC = ({ setIsStreaming(false, chatId); streamingChatIdRef.current = null; clearParts(); + _clearStreamSeed(); setCurrentUsage(null); setCurrentStreamDisplayRole('assistant'); // Reload chat from DB to reflect rolled-back state. @@ -664,6 +699,7 @@ export const ChatWindow: React.FC = ({ // stale pending-approval tool blocks from the backend race condition. setIsStreaming(false, chatId); clearParts(); + _clearStreamSeed(); // Background DB sync — delay slightly so the backend has time to commit tool // results before we reload. An immediate reload risks getting stale @@ -806,6 +842,52 @@ export const ChatWindow: React.FC = ({ return () => window.removeEventListener('agui:send-message', handler); }, [sendAGUI, currentChatId, setHeartbeatRunning, setIsStreaming]); + // Save the current streaming parts to sessionStorage on page hide so they + // can be used as a seed (or to detect a missed completion) after a refresh. + useEffect(() => { + const saveOnHide = () => { + const chatId = streamingChatIdRef.current; + if (!chatId) return; + const parts = getStreamingParts(); + if (parts.length > 0) { + _saveStreamSeed(chatId, parts); + } + }; + document.addEventListener('visibilitychange', saveOnHide); + window.addEventListener('beforeunload', saveOnHide); + return () => { + document.removeEventListener('visibilitychange', saveOnHide); + window.removeEventListener('beforeunload', saveOnHide); + }; + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [getStreamingParts]); + + // Resume a tool-approval stream via the background queue so it is + // reconnectable after a page refresh (same pattern as /chat/send). + const resumeViaQueue = useCallback(async (body: Record) => { + const chatId = (body.chat_id as string) || streamingChatIdRef.current || currentChatId; + if (!chatId) return; + + // Save current parts as seed so tryConnect's reconnect can show prior steps. + const currentParts = getStreamingParts(); + if (currentParts.length > 0) { + abandonedPartsRef.current.set(chatId, currentParts); + } + + const resp = await fetch(`${getApiBase()}/chat/send`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(body), + }); + + if (!resp.ok) { + const msg = resp.status === 409 ? 'Chat is already responding' : `Resume failed (${resp.status})`; + setStatusBar(msg, 'error', 4000); + throw new Error(msg); + } + // 202: stream registered — event bus will fire stream_started → tryConnect + }, [currentChatId, getStreamingParts, setStatusBar]); + const { handleToolApproval } = useToolApproval({ currentChatId, activeStreamingChatId, @@ -818,7 +900,7 @@ export const ChatWindow: React.FC = ({ setConfig, updateMessage, setIsStreaming, - resumeStream, + resumeStream: resumeViaQueue, resolveApproval, addApprovalDecision, consumeApprovalDecisions, @@ -1030,6 +1112,34 @@ export const ChatWindow: React.FC = ({ const chatIdAtMount = currentChatId; let cancelled = false; + // Restore stream seed saved before a page refresh so we can either: + // (a) seed an in-progress reconnect with prior visual state, or + // (b) detect that the stream completed while we were gone and reload from DB. + const savedSeed = _loadStreamSeed(chatIdAtMount); + if (savedSeed) { + _clearStreamSeed(); + if (savedSeed.parts.length > 0) { + // Pre-populate the abandoned-parts map so tryConnect below can use it + // as a visual seed — even if the snapshot hasn't arrived yet. + abandonedPartsRef.current.set(chatIdAtMount, savedSeed.parts); + } + // Mark as abandoned so the existing reload-on-return logic triggers. + abandonedStreamChatsRef.current.add(chatIdAtMount); + + // If there are pending tool approvals in the seed and no active stream, + // restore the parts directly so the approval dialog re-appears. + if (!isBusStreaming(chatIdAtMount)) { + const hasSeedApprovals = savedSeed.parts.some( + p => p.type === 'tool' && p.state === 'approval-requested' + ); + if (hasSeedApprovals) { + restorePartsFromSeed(savedSeed.parts); + setIsStreaming(true, chatIdAtMount); + streamingChatIdRef.current = chatIdAtMount; + } + } + } + // On entry: reload from DB so any stream that completed while we were away // is visible. Covers platform/heartbeat chats and regular chats whose live // stream we abandoned on a previous navigation. @@ -1097,6 +1207,7 @@ export const ChatWindow: React.FC = ({ // Stream finished cleanly — drop preserved reconnect state for this chat. streamStartByChatRef.current.delete(chatIdAtMount); abandonedPartsRef.current.delete(chatIdAtMount); + _clearStreamSeed(); if (isSocialStream) { if (richMsg.content.trim()) addMessage(richMsg, chatIdAtMount); @@ -1191,22 +1302,31 @@ export const ChatWindow: React.FC = ({ streamingChatIdRef.current = currentChatId; activeChatIdRef.current = currentChatId; stopInFlightRef.current = false; - try { - await steerStream({ - chat_id: currentChatId, - message: prompt, - config: safeConfig, - }); - } catch (error) { - console.error('Error during steer:', error); - setIsStreaming(false, currentChatId); - } finally { + + // Abort the current live connection silently, then start steer via the + // background queue (/chat/steer-send → 202 → /chat/live) so it is + // reconnectable after a page refresh. + stopAGUIStreamSilently(); + clearParts(); + + const steerChatId = currentChatId; + fetch(`${getApiBase()}/chat/steer-send`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ chat_id: steerChatId, message: prompt, config: safeConfig }), + }).then(resp => { + if (!resp.ok) { + const msg = resp.status === 409 ? 'Steer failed — chat is busy' : `Steer failed (${resp.status})`; + setStatusBar(msg, 'error', 4000); + setIsStreaming(false, steerChatId); + } + }).catch(err => { + console.error('[send] /chat/steer-send failed:', err); + setIsStreaming(false, steerChatId); + }).finally(() => { steeringRef.current = false; stopInFlightRef.current = false; - setTimeout(async () => { - try { await forceSaveNow(currentChatId); } catch { } - }, 600); - } + }); return; } @@ -1305,8 +1425,10 @@ export const ChatWindow: React.FC = ({ }); }; - // Retry handler — restores last checkpoint and re-runs the original message - const handleRetry = useCallback(async () => { + // Retry handler — restores last checkpoint and re-runs the original message. + // Uses the background queue (/chat/send) so the stream is reconnectable after + // a page refresh, matching the behaviour of the main send path. + const handleRetry = useCallback(() => { if (!currentChatId || isStreaming) return; // Strip all messages from the last user message onwards so the UI is clean @@ -1324,25 +1446,31 @@ export const ChatWindow: React.FC = ({ } const chatIdForRetry = currentChatId; + clearParts(); + streamStartByChatRef.current.set(chatIdForRetry, Date.now()); setIsStreaming(true, chatIdForRetry); streamingChatIdRef.current = chatIdForRetry; activeChatIdRef.current = chatIdForRetry; stopInFlightRef.current = false; - try { - await sendAGUI({ message: '/retry', chat_id: chatIdForRetry, config: safeConfig }); - } catch (error) { - console.error('Error during retry:', error); - if (!steeringRef.current) { + // Fire and forget — event bus stream_started will trigger tryConnect. + fetch(`${getApiBase()}/chat/send`, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ message: '/retry', chat_id: chatIdForRetry, config: safeConfig }), + }).then(resp => { + if (!resp.ok) { + const msg = resp.status === 409 ? 'Chat is already responding' : `Retry failed (${resp.status})`; + setStatusBar(msg, 'error', 4000); setIsStreaming(false, chatIdForRetry); } - } finally { + }).catch(err => { + console.error('[handleRetry] /chat/send failed:', err); + setIsStreaming(false, chatIdForRetry); + }).finally(() => { stopInFlightRef.current = false; - setTimeout(async () => { - try { await forceSaveNow(chatIdForRetry); } catch { } - }, 600); - } - }, [currentChatId, isStreaming, messages, truncateMessagesFrom, setIsStreaming, sendAGUI, safeConfig, forceSaveNow]); + }); + }, [currentChatId, isStreaming, messages, truncateMessagesFrom, setIsStreaming, clearParts, safeConfig, setStatusBar]); // Stop streaming handler const stopStreaming = async () => { diff --git a/frontend/src/hooks/useAGUI.ts b/frontend/src/hooks/useAGUI.ts index ab0464bf..fab28a76 100644 --- a/frontend/src/hooks/useAGUI.ts +++ b/frontend/src/hooks/useAGUI.ts @@ -35,6 +35,11 @@ interface UseAGUIReturn { /** Read the current parts synchronously (e.g. to snapshot before switching chats) */ getParts: () => AGUIPart[]; clearParts: () => void; + /** + * Restore saved parts directly (e.g. after a page refresh) without starting a + * new stream. Used to re-display pending tool-approval dialogs from sessionStorage. + */ + restorePartsFromSeed: (seed: AGUIPart[]) => void; /** Remove an inline A2UI surface part by surface id (e.g. after ask_question is answered) */ removeInlineSurface: (surfaceId: string) => void; /** Optimistically resolve a tool approval (instantly updates UI before backend responds) */ @@ -465,6 +470,14 @@ export function useAGUI(options: UseAGUIOptions): UseAGUIReturn { const getParts = useCallback(() => partsRef.current, []); + const restorePartsFromSeed = useCallback((seed: AGUIPart[]) => { + setParts(seed); + partsRef.current = seed; + setStatus('idle'); + const approvalCount = seed.filter(p => p.type === 'tool' && p.state === 'approval-requested').length; + setPendingApprovalCountSync(approvalCount); + }, [setPendingApprovalCountSync]); + // Optimistically update a tool part's state when user approves/denies // so buttons disappear instantly (no waiting for backend round-trip) const resolveApproval = useCallback((approvalId: string, approved: boolean) => { @@ -883,5 +896,5 @@ export function useAGUI(options: UseAGUIOptions): UseAGUIReturn { } }, [resetApprovalTracking, setPendingApprovalCountSync]); - return { parts, status, error, sendMessage, resumeStream, steerStream, stop, stopSilently, getParts, clearParts, removeInlineSurface, resolveApproval, pendingApprovalCount, addApprovalDecision, consumeApprovalDecisions }; + return { parts, status, error, sendMessage, resumeStream, steerStream, stop, stopSilently, getParts, clearParts, restorePartsFromSeed, removeInlineSurface, resolveApproval, pendingApprovalCount, addApprovalDecision, consumeApprovalDecisions }; } diff --git a/src/suzent/routes/chat_routes.py b/src/suzent/routes/chat_routes.py index a1279afe..a70800ab 100644 --- a/src/suzent/routes/chat_routes.py +++ b/src/suzent/routes/chat_routes.py @@ -219,6 +219,9 @@ async def chat_send(request: Request) -> JSONResponse: Returns 202 immediately. The response streams through /chat/live and is observable via /events/stream, enabling the frontend to watch other chats while this one processes. + + Also accepts ``resume_approvals`` to resume a paused tool-approval stream + without a direct SSE connection (reconnectable after page refresh). """ try: data = await request.json() @@ -230,10 +233,11 @@ async def chat_send(request: Request) -> JSONResponse: config = data.get("config", {}) files_list = data.get("files", []) file_mentions = data.get("file_mentions", []) + resume_approvals = data.get("resume_approvals", []) if not chat_id: return JSONResponse({"error": "chat_id is required"}, status_code=400) - if not message and not files_list: + if not message and not files_list and not resume_approvals: return JSONResponse({"error": "Empty message"}, status_code=400) from suzent.core.chat_processor import ChatProcessor from suzent.agent_manager import build_agent_config @@ -253,6 +257,7 @@ async def _run() -> None: files=files_list, file_mentions=file_mentions, config_override=config_override, + resume_approvals=resume_approvals or [], ) async for chunk in generator: await stream_queue.put(chunk) @@ -274,6 +279,64 @@ async def _run() -> None: return JSONResponse({"chat_id": chat_id}, status_code=202) +async def steer_chat_send(request: Request) -> JSONResponse: + """ + POST /chat/steer-send — Interrupt the current agent run and redirect via background queue. + + Returns 202 immediately. The steer response streams through /chat/live so the + frontend can reconnect after a page refresh, matching the behaviour of /chat/send. + """ + try: + data = await request.json() + except Exception: + return JSONResponse({"error": "Invalid JSON"}, status_code=400) + + chat_id = data.get("chat_id") + message = data.get("message", "").strip() + config = data.get("config", {}) + + if not chat_id: + return JSONResponse({"error": "chat_id is required"}, status_code=400) + if not message: + return JSONResponse({"error": "message is required"}, status_code=400) + + from suzent.core.chat_processor import ChatProcessor + from suzent.agent_manager import build_agent_config + + processor = ChatProcessor() + config_override = build_agent_config(config, require_social_tool=False) + stream_queue = try_register_background_stream(chat_id) + if stream_queue is None: + return JSONResponse({"error": "Chat is already streaming"}, status_code=409) + + async def _run() -> None: + try: + generator = processor.process_steer( + chat_id=chat_id, + user_id=CONFIG.user_id, + steer_message=message, + config_override=config_override, + ) + async for chunk in generator: + await stream_queue.put(chunk) + except Exception as exc: + logger.error(f"[steer_send] Background steer error for {chat_id}: {exc}") + error_payload = ( + f'data: {{"type":"RUN_ERROR","message":{json.dumps(str(exc))}}}\n\n' + ) + try: + await stream_queue.put(error_payload) + except Exception: + pass + finally: + await stream_queue.put(None) + + task = asyncio.create_task(_run()) + _chat_send_tasks.add(task) + task.add_done_callback(_chat_send_tasks.discard) + return JSONResponse({"chat_id": chat_id}, status_code=202) + + async def retry_chat(request: Request) -> StreamingResponse: """ Restore the last retry checkpoint for a chat and re-run the original user message. diff --git a/src/suzent/server.py b/src/suzent/server.py index fca6c8fd..c18c1c0a 100644 --- a/src/suzent/server.py +++ b/src/suzent/server.py @@ -38,6 +38,7 @@ mark_chat_read, retry_chat, steer_chat, + steer_chat_send, stop_chat, update_chat, ) @@ -681,6 +682,7 @@ async def lifespan(app): Route("/chat/live", live_stream, methods=["POST"]), Route("/chat/stop", stop_chat, methods=["POST"]), Route("/chat/steer", steer_chat, methods=["POST"]), + Route("/chat/steer-send", steer_chat_send, methods=["POST"]), Route("/chat/approve-tool", approve_tool, methods=["POST"]), Route("/chat/deactivate-tool", deactivate_tool, methods=["POST"]), Route("/chat/compact", compact_chat, methods=["POST"]), diff --git a/src/suzent/streaming.py b/src/suzent/streaming.py index 095c3a58..4f0143c0 100644 --- a/src/suzent/streaming.py +++ b/src/suzent/streaming.py @@ -417,6 +417,36 @@ async def stream_agent_responses( if chat_id: register_active_stream(chat_id, out_queue) + # --- Mid-run checkpoint helper --- + async def _save_mid_run_checkpoint(messages: list) -> None: + """Persist a partial agent state snapshot after each completed tool batch. + + Called as a fire-and-forget task so it never blocks the stream. On + disconnect the DB will have the last completed tool batch saved, letting + the next turn resume from there instead of from the start of the run. + """ + if not chat_id or not messages: + return + try: + from suzent.core.agent_serializer import serialize_state + from suzent.database import get_database as _ckpt_get_db + + _model_id = getattr(agent, "_model_id", None) + _tool_names = getattr(agent, "_tool_names", []) + + def _sync_save() -> None: + _st = serialize_state( + messages, model_id=_model_id, tool_names=_tool_names + ) + _ckpt_get_db().update_chat(chat_id, agent_state=_st) + + await asyncio.to_thread(_sync_save) + logger.debug( + f"[Streaming] Mid-run checkpoint saved ({len(messages)} messages)" + ) + except Exception as _ckpt_err: + logger.debug(f"[Streaming] Mid-run checkpoint failed: {_ckpt_err}") + # --- Background agent runner (stateless resume) --- async def _agent_runner() -> None: """Run the agent in a background task. @@ -445,6 +475,16 @@ async def _agent_runner() -> None: if current_deferred: run_kwargs["deferred_tool_results"] = current_deferred + # Per-run accumulators for mid-run checkpointing. + # _chk_resp_parts: index → complete ModelResponsePart (from PartEndEvent) + # _chk_tool_returns: ToolReturnParts collected this batch + # _chk_in_flight: tool calls awaiting their result event + # _chk_base: the message history baseline for this run iteration + _chk_resp_parts: Dict[int, Any] = {} + _chk_tool_returns: list = [] + _chk_in_flight: int = 0 + _chk_base: list = list(history or []) + last_run_result = None logger.debug("[Streaming] Calling agent.run_stream_events()...") async for event in _iter_stream_events_with_timeout( @@ -456,6 +496,48 @@ async def _agent_runner() -> None: logger.debug( f"[Streaming] Received event from agent: {type(event).__name__}" ) + + # ── Mid-run checkpoint tracking ────────────────── + _event_kind = getattr(event, "event_kind", "") + if _event_kind == "part_end": + # Collect the complete part (not a delta) so we + # can reconstruct a valid ModelResponse later. + _chk_resp_parts[event.index] = event.part + elif _event_kind == "function_tool_call": + _chk_in_flight += 1 + elif _event_kind == "function_tool_result": + from pydantic_ai.messages import ( + ToolReturnPart as _TRP, + ModelResponse as _MResp, + ModelRequest as _MReq, + ) + if isinstance(event.result, _TRP): + _chk_tool_returns.append(event.result) + _chk_in_flight = max(0, _chk_in_flight - 1) + if _chk_in_flight == 0 and _chk_tool_returns: + # All tools in this batch have completed. + # Build a proper checkpoint from accumulated parts. + _resp_parts = [ + _chk_resp_parts[i] + for i in sorted(_chk_resp_parts) + ] + _checkpoint = _chk_base + [ + _MResp(parts=_resp_parts), + _MReq(parts=list(_chk_tool_returns)), + ] + asyncio.create_task( + _save_mid_run_checkpoint(_checkpoint) + ) + # Also update partial_history so the finally + # block has the latest state on crash/cancel. + partial_history = _checkpoint + # Advance base and reset per-batch state + # so the next tool batch starts clean. + _chk_base = list(_checkpoint) + _chk_resp_parts = {} + _chk_tool_returns = [] + # ──────────────────────────────────────────────── + if isinstance(event, AgentRunResultEvent): last_run_result = event.result final_response_text = str(event.result.output) @@ -782,18 +864,52 @@ async def _drain_a2ui_events() -> None: elif msg_type == "approval": # HITL: emit as AG-UI CustomEvent with all approval info + approval_info = { + "approvalId": payload["request_id"], + "toolCallId": payload.get("tool_call_id") + or payload["request_id"], + "toolName": payload.get("tool_name", ""), + "args": payload.get("args", {}), + "chatId": chat_id, + } await _queue_custom_event( out_queue, "tool_approval_request", - { - "approvalId": payload["request_id"], - "toolCallId": payload.get("tool_call_id") - or payload["request_id"], - "toolName": payload.get("tool_name", ""), - "args": payload.get("args", {}), - "chatId": chat_id, - }, + approval_info, ) + # Persist pending approval to DB so the frontend can + # reconstruct the approval dialog after a page refresh. + if chat_id: + try: + from suzent.database import get_database as _get_db + + def _save_pending_approval(): + _db = _get_db() + _chat = _db.get_chat(chat_id) + if _chat is not None: + _cfg = dict(_chat.config or {}) + # Accumulate multiple pending approvals + existing = _cfg.get("_pending_approvals") or [] + if isinstance(existing, list): + existing = [ + a for a in existing + if a.get("toolCallId") != approval_info["toolCallId"] + ] + else: + existing = [] + existing.append({ + "approvalId": approval_info["approvalId"], + "toolCallId": approval_info["toolCallId"], + "toolName": approval_info["toolName"], + "args": approval_info["args"], + "savedAt": __import__("datetime").datetime.utcnow().isoformat(), + }) + _cfg["_pending_approvals"] = existing + _db.update_chat(chat_id, config=_cfg) + + await asyncio.to_thread(_save_pending_approval) + except Exception as _pa_err: + logger.debug(f"[Streaming] Failed to save pending_approval: {_pa_err}") elif msg_type == "tool_recovery": # HITL: emit recovered tool result with output @@ -915,6 +1031,28 @@ async def encode_worker() -> None: if not getattr(deps, "is_suspended", False): # Stream ended (not paused for approvals), drop any stale cache. pop_pending_auto_approvals(chat_id) + # Clear persisted pending approvals so the frontend doesn't + # show a stale dialog on next load. + try: + from suzent.database import get_database as _get_db2 + + async def _clear_pending_approvals_task(): + try: + def _sync_clear(): + _db2 = _get_db2() + _chat2 = _db2.get_chat(chat_id) + if _chat2 is not None: + _cfg2 = dict(_chat2.config or {}) + if "_pending_approvals" in _cfg2: + del _cfg2["_pending_approvals"] + _db2.update_chat(chat_id, config=_cfg2) + await asyncio.to_thread(_sync_clear) + except Exception: + pass + + asyncio.create_task(_clear_pending_approvals_task()) + except Exception as _cp_err: + logger.debug(f"[Streaming] Failed to clear pending_approvals: {_cp_err}") existing = stream_controls.get(chat_id) if existing is control: