From 7559a757c82782ad608b14605b41e65f8d0fadec Mon Sep 17 00:00:00 2001 From: Praneeth Perumalla Date: Thu, 11 Jun 2026 17:07:28 +0530 Subject: [PATCH 1/2] fix: recover streaming responses after network interruption --- backend/app.py | 11 ++ backend/models/schemas.py | 1 + backend/routes/chat.py | 180 ++++++++++++++++-- backend/tests/test_streaming_recovery.py | 231 +++++++++++++++++++++++ frontend/src/App.jsx | 2 +- frontend/src/utils/api.js | 82 ++++++-- 6 files changed, 470 insertions(+), 37 deletions(-) create mode 100644 backend/tests/test_streaming_recovery.py diff --git a/backend/app.py b/backend/app.py index f7109ed..f8846e6 100644 --- a/backend/app.py +++ b/backend/app.py @@ -31,6 +31,8 @@ FRONTEND_DIST = Path(os.getenv("FRONTEND_DIST", "/app/frontend/dist")) +import asyncio + @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Starting LocalMind v2.0...") @@ -38,9 +40,18 @@ async def lifespan(app: FastAPI): os.makedirs("./data/chromadb", exist_ok=True) os.makedirs("./data/exports", exist_ok=True) init_db() + + # Start stream cleanup task + from routes.chat import clean_expired_streams + cleanup_task = asyncio.create_task(clean_expired_streams()) + logger.info("LocalMind v2.0 ready!") yield logger.info("👋 Shutting down...") + + # Cancel stream cleanup task + cleanup_task.cancel() + await asyncio.gather(cleanup_task, return_exceptions=True) app = FastAPI( diff --git a/backend/models/schemas.py b/backend/models/schemas.py index 0dff3e4..adf3f86 100644 --- a/backend/models/schemas.py +++ b/backend/models/schemas.py @@ -26,6 +26,7 @@ class ChatRequest(BaseModel): use_documents: bool = True temperature: float = Field(default=0.7, ge=0.0, le=2.0) language: str = "en" + resume_offset: Optional[int] = 0 class ChatResponse(BaseModel): diff --git a/backend/routes/chat.py b/backend/routes/chat.py index 02be5d6..722e99f 100644 --- a/backend/routes/chat.py +++ b/backend/routes/chat.py @@ -1,5 +1,7 @@ """Chat routes — /api/chat — supports normal + streaming""" +import asyncio +import time import json from types import SimpleNamespace @@ -21,6 +23,110 @@ def _retrieve_context(*args, **kwargs): rag_service = SimpleNamespace(retrieve_context=_retrieve_context) +# Global registry for active streams +ACTIVE_STREAMS = {} + + +class StreamBuffer: + def __init__(self, session_id: str, prompt: str): + self.session_id = session_id + self.prompt = prompt + self.buffer = "" + self.completed = False + self.listeners = set() + self.created_at = time.time() + self.updated_at = time.time() + self.completed_at = None + self.error = None + self.sources = [] + + +async def clean_expired_streams(): + while True: + try: + await asyncio.sleep(10) + now = time.time() + for session_id, buffer in list(ACTIVE_STREAMS.items()): + # Evict completed or failed streams after 120 seconds (2 minutes) + if (buffer.completed or buffer.error is not None) and buffer.completed_at: + if now - buffer.completed_at > 120: + ACTIVE_STREAMS.pop(session_id, None) + # Evict abandoned or running streams after 300 seconds (5 minutes) + elif now - buffer.created_at > 300: + ACTIVE_STREAMS.pop(session_id, None) + except asyncio.CancelledError: + break + except Exception: + pass + + +async def background_generator(buffer: StreamBuffer, req, context, history, sources): + try: + async for token in ollama_service.chat_stream( + message=req.message, + model=req.model, + context=context, + history=history, + language=req.language, + temperature=req.temperature, + ): + buffer.buffer += token + buffer.updated_at = time.time() + # Push token to all active listeners + for listener in list(buffer.listeners): + await listener.put({"token": token}) + + # Save successfully completed message + db_service.save_message(buffer.session_id, "assistant", buffer.buffer, sources) + buffer.completed = True + buffer.sources = sources + buffer.completed_at = time.time() + + for listener in list(buffer.listeners): + await listener.put({"done": True, "sources": sources}) + + except Exception as e: + buffer.error = str(e) + buffer.completed_at = time.time() + # Save partial response + if buffer.buffer: + db_service.save_message(buffer.session_id, "assistant", buffer.buffer, sources) + for listener in list(buffer.listeners): + await listener.put({"error": str(e)}) + + +async def stream_from_buffer(buffer: StreamBuffer, resume_offset: int): + # 1. Send already accumulated tokens from resume_offset + accumulated = buffer.buffer + if resume_offset < len(accumulated): + yield f"data: {json.dumps({'token': accumulated[resume_offset:]})}\n\n" + + # 2. If already finished, stop + if buffer.completed: + yield f"data: {json.dumps({'done': True, 'sources': buffer.sources})}\n\n" + return + if buffer.error: + yield f"data: {json.dumps({'error': buffer.error})}\n\n" + return + + # 3. Wait for new tokens + listener = asyncio.Queue() + buffer.listeners.add(listener) + try: + while True: + event = await listener.get() + if "error" in event: + yield f"data: {json.dumps({'error': event['error']})}\n\n" + break + if "token" in event: + yield f"data: {json.dumps({'token': event['token']})}\n\n" + if "done" in event: + yield f"data: {json.dumps({'done': True, 'sources': event['sources']})}\n\n" + break + finally: + buffer.listeners.discard(listener) + + @router.post("/", response_model=ChatResponse) async def chat(req: ChatRequest): """Standard (non-streaming) chat endpoint.""" @@ -58,31 +164,65 @@ async def chat_stream(req: ChatRequest): if not await ollama_service.is_ollama_running(): raise HTTPException(503, "Ollama not running. Run: `ollama serve`") - db_service.create_session(req.session_id, model=req.model) + resume_offset = req.resume_offset or 0 + is_resume = resume_offset > 0 + + # 1. Check active stream buffers + buffer = ACTIVE_STREAMS.get(req.session_id) + if buffer and buffer.prompt == req.message: + return StreamingResponse(stream_from_buffer(buffer, resume_offset), media_type="text/event-stream") + + # 2. Check completed stream in SQLite history = db_service.get_history(req.session_id) + if is_resume and history: + if history[-1]["role"] == "assistant" and len(history) >= 2: + prev_msg = history[-2] + if prev_msg["role"] == "user" and prev_msg["content"] == req.message: + async def stream_from_db(): + full_content = history[-1]["content"] + sources = [] + messages_full = db_service.get_messages_full(req.session_id) + if messages_full: + sources = messages_full[-1].get("sources", []) + if resume_offset < len(full_content): + yield f"data: {json.dumps({'token': full_content[resume_offset:]})}\n\n" + yield f"data: {json.dumps({'done': True, 'sources': sources})}\n\n" + return StreamingResponse(stream_from_db(), media_type="text/event-stream") + + # 3. Deduplicate user message + user_msg_exists = False + if history: + if history[-1]["role"] == "user" and history[-1]["content"] == req.message: + user_msg_exists = True + elif len(history) >= 2 and history[-1]["role"] == "assistant" and history[-2]["role"] == "user" and history[-2]["content"] == req.message: + user_msg_exists = True + + db_service.create_session(req.session_id, model=req.model) + if not user_msg_exists: + db_service.save_message(req.session_id, "user", req.message) + history = db_service.get_history(req.session_id) + + # 4. Clean history + cleaned_history = [] + if history and history[-1]["role"] == "assistant": + cleaned_history = history[:-1] + else: + cleaned_history = history context, sources = "", [] if req.use_documents: context, sources = rag_service.retrieve_context(req.message, req.session_id) - db_service.save_message(req.session_id, "user", req.message) - - full_reply = [] - - async def event_stream(): - async for token in ollama_service.chat_stream( - message=req.message, - model=req.model, - context=context, - history=history, - language=req.language, - temperature=req.temperature, - ): - full_reply.append(token) - yield f"data: {json.dumps({'token': token})}\n\n" + # Create new stream buffer and task + buffer = StreamBuffer(req.session_id, req.message) + ACTIVE_STREAMS[req.session_id] = buffer - complete = "".join(full_reply) - db_service.save_message(req.session_id, "assistant", complete, sources) - yield f"data: {json.dumps({'done': True, 'sources': sources})}\n\n" + asyncio.create_task(background_generator( + buffer=buffer, + req=req, + context=context, + history=cleaned_history, + sources=sources + )) - return StreamingResponse(event_stream(), media_type="text/event-stream") + return StreamingResponse(stream_from_buffer(buffer, resume_offset), media_type="text/event-stream") diff --git a/backend/tests/test_streaming_recovery.py b/backend/tests/test_streaming_recovery.py new file mode 100644 index 0000000..cdfa2f4 --- /dev/null +++ b/backend/tests/test_streaming_recovery.py @@ -0,0 +1,231 @@ +import asyncio +import tempfile +import time +import pytest +from unittest.mock import AsyncMock, patch + +import services.db_service as db +from routes.chat import chat_stream, ACTIVE_STREAMS, StreamBuffer, clean_expired_streams +from models.schemas import ChatRequest + +# Initialize a temp SQLite database for tests +_tmp = tempfile.mktemp(suffix=".db") +db.DB_PATH = _tmp +db.init_db() + + +@pytest.fixture(autouse=True) +def setup_db(): + # Clear tables before each test + with db.get_db() as conn: + conn.execute("DELETE FROM messages") + conn.execute("DELETE FROM sessions") + ACTIVE_STREAMS.clear() + + +async def mock_chat_stream(*args, **kwargs): + tokens = ["Hello", " world", "!", " How", " are", " you?"] + for t in tokens: + await asyncio.sleep(0.02) + yield t + + +@pytest.mark.asyncio +@patch("routes.chat.ollama_service.is_ollama_running", new_callable=AsyncMock, return_value=True) +@patch("routes.chat.ollama_service.chat_stream", side_effect=mock_chat_stream) +@patch("routes.chat.rag_service.retrieve_context", return_value=("", [])) +async def test_normal_stream_completion(mock_rag, mock_stream, mock_ollama): + req = ChatRequest( + message="Hi", + session_id="session-1", + model="llama3", + use_documents=False, + resume_offset=0 + ) + + response = await chat_stream(req) + assert response is not None + + # Read all lines from stream + chunks = [] + async for line in response.body_iterator: + if line.strip(): + chunks.append(line) + + # Should yield tokens and done event + assert len(chunks) > 0 + assert "Hello" in chunks[0] + assert "done" in chunks[-1] + + # Verify message is saved to DB + messages = db.get_messages_full("session-1") + assert len(messages) == 2 # user + assistant + assert messages[0]["role"] == "user" + assert messages[1]["role"] == "assistant" + assert messages[1]["content"] == "Hello world! How are you?" + + +@pytest.mark.asyncio +@patch("routes.chat.ollama_service.is_ollama_running", new_callable=AsyncMock, return_value=True) +@patch("routes.chat.ollama_service.chat_stream", side_effect=mock_chat_stream) +@patch("routes.chat.rag_service.retrieve_context", return_value=("", [])) +async def test_client_disconnect_background_finishes(mock_rag, mock_stream, mock_ollama): + req = ChatRequest( + message="Hi", + session_id="session-2", + model="llama3", + use_documents=False, + resume_offset=0 + ) + + response = await chat_stream(req) + + # Simulate client reading only one chunk and disconnecting (cancelling stream) + iterator = response.body_iterator.__aiter__() + first_chunk = await iterator.__anext__() + assert "Hello" in first_chunk + + # Client disconnects -> we discard/stop reading from the iterator + # Verify that ACTIVE_STREAMS contains the buffer + buffer = ACTIVE_STREAMS.get("session-2") + assert buffer is not None + assert buffer.completed is False + + # Wait for the background generator to finish running + await asyncio.sleep(0.2) + + # Verify it finished and saved to DB + assert buffer.completed is True + messages = db.get_messages_full("session-2") + assert len(messages) == 2 + assert messages[1]["content"] == "Hello world! How are you?" + + +@pytest.mark.asyncio +@patch("routes.chat.ollama_service.is_ollama_running", new_callable=AsyncMock, return_value=True) +@patch("routes.chat.ollama_service.chat_stream", side_effect=mock_chat_stream) +@patch("routes.chat.rag_service.retrieve_context", return_value=("", [])) +async def test_client_disconnect_and_reconnect_during_generation(mock_rag, mock_stream, mock_ollama): + # 1. Start initial request + req1 = ChatRequest( + message="Hi", + session_id="session-3", + model="llama3", + use_documents=False, + resume_offset=0 + ) + response1 = await chat_stream(req1) + + # Read first chunk ("Hello") and disconnect + iterator1 = response1.body_iterator.__aiter__() + c1 = await iterator1.__anext__() + assert "Hello" in c1 + + # 2. Reconnect immediately with resume_offset = 5 ("Hello".length) + # Background generation is still running! + req2 = ChatRequest( + message="Hi", + session_id="session-3", + model="llama3", + use_documents=False, + resume_offset=5 + ) + response2 = await chat_stream(req2) + + # Read the rest of the stream + chunks = [] + async for line in response2.body_iterator: + if line.strip(): + chunks.append(line) + + # Verify it resumes from the next token " world" and does not duplicate "Hello" + assert "Hello" not in chunks[0] + assert "world" in chunks[0] + assert "done" in chunks[-1] + + # Wait for background task to fully complete + await asyncio.sleep(0.2) + + # Verify SQLite has exactly 1 user and 1 assistant message (no duplicates) + messages = db.get_messages_full("session-3") + assert len(messages) == 2 + assert messages[0]["role"] == "user" + assert messages[1]["role"] == "assistant" + assert messages[1]["content"] == "Hello world! How are you?" + + +@pytest.mark.asyncio +@patch("routes.chat.ollama_service.is_ollama_running", new_callable=AsyncMock, return_value=True) +@patch("routes.chat.ollama_service.chat_stream", side_effect=mock_chat_stream) +@patch("routes.chat.rag_service.retrieve_context", return_value=("", [])) +async def test_reconnect_after_generation_finished(mock_rag, mock_stream, mock_ollama): + # 1. Initial request + req1 = ChatRequest( + message="Hi", + session_id="session-4", + model="llama3", + use_documents=False, + resume_offset=0 + ) + response1 = await chat_stream(req1) + + # Disconnect immediately without reading + # Let generation finish in the background + await asyncio.sleep(0.2) + + # Evict stream from ACTIVE_STREAMS to simulate server cleanup/restart + ACTIVE_STREAMS.clear() + + # 2. Reconnect. Response should be served from SQLite database! + req2 = ChatRequest( + message="Hi", + session_id="session-4", + model="llama3", + use_documents=False, + resume_offset=11 # "Hello world".length + ) + response2 = await chat_stream(req2) + + chunks = [] + async for line in response2.body_iterator: + if line.strip(): + chunks.append(line) + + # Verify it starts after "Hello world" -> yields "!" and rest of response + assert "Hello" not in chunks[0] + assert "world" not in chunks[0] + assert "!" in chunks[0] + assert "done" in chunks[-1] + + +@pytest.mark.asyncio +async def test_ttl_cleanup(): + # 1. Add completed and active streams + ACTIVE_STREAMS["completed"] = StreamBuffer("completed", "prompt") + ACTIVE_STREAMS["completed"].completed = True + ACTIVE_STREAMS["completed"].completed_at = time.time() - 150 # 150s ago (expired) + + ACTIVE_STREAMS["completed_fresh"] = StreamBuffer("completed_fresh", "prompt") + ACTIVE_STREAMS["completed_fresh"].completed = True + ACTIVE_STREAMS["completed_fresh"].completed_at = time.time() - 30 # 30s ago (fresh) + + ACTIVE_STREAMS["active_stale"] = StreamBuffer("active_stale", "prompt") + ACTIVE_STREAMS["active_stale"].created_at = time.time() - 400 # 400s ago (stale) + + ACTIVE_STREAMS["active_fresh"] = StreamBuffer("active_fresh", "prompt") + ACTIVE_STREAMS["active_fresh"].created_at = time.time() - 50 # 50s ago (fresh) + + # 2. Run one cycle of cleaner + now = time.time() + for session_id, buffer in list(ACTIVE_STREAMS.items()): + if (buffer.completed or buffer.error is not None) and buffer.completed_at: + if now - buffer.completed_at > 120: + ACTIVE_STREAMS.pop(session_id, None) + elif now - buffer.created_at > 300: + ACTIVE_STREAMS.pop(session_id, None) + + # 3. Assert correct eviction + assert "completed" not in ACTIVE_STREAMS + assert "active_stale" not in ACTIVE_STREAMS + assert "completed_fresh" in ACTIVE_STREAMS + assert "active_fresh" in ACTIVE_STREAMS diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 6268a53..d6c6bf5 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -68,7 +68,7 @@ export default function App() { } ); } catch (e) { - setMessages(prev => prev.map(m => m.id === aiMsg.id ? { ...m, content: e.message, streaming: false } : m)); + setMessages(prev => prev.map(m => m.id === aiMsg.id ? { ...m, content: m.content + `\n\n[Connection lost: ${e.message}]`, streaming: false } : m)); } finally { setStreaming(false); } } else { setLoading(true); diff --git a/frontend/src/utils/api.js b/frontend/src/utils/api.js index 1e467bd..71194a7 100644 --- a/frontend/src/utils/api.js +++ b/frontend/src/utils/api.js @@ -38,22 +38,72 @@ export async function uploadDocument(file, session_id) { } export function streamMessage(body, onToken, onDone) { - return fetch(`${BASE}/chat/stream`, { - method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify(body), - }).then(res => { - const reader = res.body.getReader(); const decoder = new TextDecoder(); - function pump() { - return reader.read().then(({ done, value }) => { - if (done) return; - decoder.decode(value).split("\n").forEach(line => { - if (line.startsWith("data: ")) { - try { const d = JSON.parse(line.slice(6)); if (d.token) onToken(d.token); if (d.done) onDone(d.sources||[]); } catch {} + let accumulatedText = ""; + let sourcesList = []; + let doneReceived = false; + let retriesLeft = 3; + + function runStream(offset = 0) { + const requestBody = { ...body, resume_offset: offset }; + + return fetch(`${BASE}/chat/stream`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify(requestBody), + }) + .then(res => { + if (!res.ok) { + throw new Error(`HTTP error ${res.status}`); + } + const reader = res.body.getReader(); + const decoder = new TextDecoder(); + + function pump() { + return reader.read().then(({ done, value }) => { + if (done) { + if (doneReceived) { + return; + } + throw new Error("Stream closed prematurely"); } + + const text = decoder.decode(value, { stream: true }); + text.split("\n").forEach(line => { + if (line.startsWith("data: ")) { + try { + const d = JSON.parse(line.slice(6)); + if (d.token) { + accumulatedText += d.token; + onToken(d.token); + } + if (d.done) { + doneReceived = true; + sourcesList = d.sources || []; + onDone(sourcesList); + } + } catch (e) { + // Ignore parse errors + } + } + }); + return pump(); }); - return pump(); - }); - } - return pump(); - }); + } + return pump(); + }) + .catch(err => { + if (doneReceived) { + return; + } + if (retriesLeft > 0) { + retriesLeft--; + // Wait 1 second before retrying + return new Promise(resolve => setTimeout(resolve, 1000)) + .then(() => runStream(accumulatedText.length)); + } + throw err; + }); + } + + return runStream(0); } From 7d8339b0838f1b76f89a689368b20c7e9aa2f576 Mon Sep 17 00:00:00 2001 From: Praneeth Perumalla Date: Sun, 14 Jun 2026 22:34:04 +0530 Subject: [PATCH 2/2] style: fix lint warnings reported by ruff --- backend/app.py | 4 ++-- backend/routes/chat.py | 2 -- backend/tests/test_streaming_recovery.py | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/backend/app.py b/backend/app.py index dc601c4..40f8648 100644 --- a/backend/app.py +++ b/backend/app.py @@ -5,6 +5,7 @@ import logging import os +import asyncio from contextlib import asynccontextmanager from pathlib import Path @@ -33,8 +34,7 @@ FRONTEND_DIST = Path(os.getenv("FRONTEND_DIST", "/app/frontend/dist")) -import asyncio - +# Starting lifespan code block @asynccontextmanager async def lifespan(app: FastAPI): logger.info("Starting LocalMind v2.0...") diff --git a/backend/routes/chat.py b/backend/routes/chat.py index e6a7296..594575a 100644 --- a/backend/routes/chat.py +++ b/backend/routes/chat.py @@ -11,7 +11,6 @@ from models.schemas import ChatRequest, ChatResponse from services import ollama_service, db_service -import time import psutil def _get_memory_usage(): @@ -189,7 +188,6 @@ async def chat_stream(req: ChatRequest): if not await ollama_service.is_ollama_running(): raise HTTPException(503, "Ollama not running. Run: `ollama serve`") - first_token_time = None start_time = time.perf_counter() resume_offset = req.resume_offset or 0 diff --git a/backend/tests/test_streaming_recovery.py b/backend/tests/test_streaming_recovery.py index cdfa2f4..6c1e3f3 100644 --- a/backend/tests/test_streaming_recovery.py +++ b/backend/tests/test_streaming_recovery.py @@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, patch import services.db_service as db -from routes.chat import chat_stream, ACTIVE_STREAMS, StreamBuffer, clean_expired_streams +from routes.chat import chat_stream, ACTIVE_STREAMS, StreamBuffer from models.schemas import ChatRequest # Initialize a temp SQLite database for tests @@ -167,7 +167,7 @@ async def test_reconnect_after_generation_finished(mock_rag, mock_stream, mock_o use_documents=False, resume_offset=0 ) - response1 = await chat_stream(req1) + await chat_stream(req1) # Disconnect immediately without reading # Let generation finish in the background