From 748ef519b5e5fc7990363bd2f76ae99075428e7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 16:17:06 +0800 Subject: [PATCH 01/12] feat: update manager for async add --- .../tree_text_memory/organize/manager.py | 105 ++++++++++++++---- 1 file changed, 81 insertions(+), 24 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 54776134b..ddff5d7c1 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -1,3 +1,4 @@ +import re import traceback import uuid @@ -19,6 +20,37 @@ logger = get_logger(__name__) +def extract_working_binding_ids(mem_items: list[TextualMemoryItem]) -> set[str]: + """ + Scan enhanced memory items for background hints like + "[working_binding:]" and collect those working memory IDs. + + We store the working<->long binding inside metadata.background when + initially adding memories in async mode, so we can later clean up + the temporary WorkingMemory nodes after mem_reader produces the + final LongTermMemory/UserMemory. + + Args: + mem_items: list of TextualMemoryItem we just added (enhanced memories) + + Returns: + A set of working memory IDs (as strings) that should be deleted. + """ + bindings: set[str] = set() + pattern = re.compile(r"\[working_binding:([0-9a-fA-F-]{36})\]") + for item in mem_items: + try: + bg = getattr(item.metadata, "background", "") or "" + except Exception: + bg = "" + if not isinstance(bg, str): + continue + match = pattern.search(bg) + if match: + bindings.add(match.group(1)) + return bindings + + class MemoryManager: def __init__( self, @@ -129,15 +161,28 @@ def _refresh_memory_size(self, user_name: str | None = None) -> None: def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = None): """ - Process and add memory to different memory types (WorkingMemory, LongTermMemory, UserMemory). - This method runs asynchronously to process each memory item. + Process and add memory to different memory types. + + Behavior: + 1. Always create a WorkingMemory node from `memory` and get its node id. + 2. If `memory.metadata.memory_type` is "LongTermMemory" or "UserMemory", + also create a corresponding long/user node. + - In async mode, that long/user node's metadata will include + `working_binding` in `background` which records the WorkingMemory + node id created in step 1. + 3. Return ONLY the ids of the long/user nodes (NOT the working node id), + which preserves the previous external contract of `add()`. """ ids: list[str] = [] futures = [] + working_id = str(uuid.uuid4()) + with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex: - f_working = ex.submit(self._add_memory_to_db, memory, "WorkingMemory", user_name) - futures.append(f_working) + f_working = ex.submit( + self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id + ) + futures.append(("working", f_working)) if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"): f_graph = ex.submit( @@ -145,13 +190,14 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non memory=memory, memory_type=memory.metadata.memory_type, user_name=user_name, + working_binding=working_id, ) - futures.append(f_graph) + futures.append(("long", f_graph)) - for fut in as_completed(futures): + for kind, fut in futures: try: res = fut.result() - if isinstance(res, str) and res: + if kind != "working" and isinstance(res, str) and res: ids.append(res) except Exception: logger.warning("Parallel memory processing failed:\n%s", traceback.format_exc()) @@ -159,39 +205,50 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non return ids def _add_memory_to_db( - self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + self, + memory: TextualMemoryItem, + memory_type: str, + user_name: str | None = None, + forced_id: str | None = None, ) -> str: """ Add a single memory item to the graph store, with FIFO logic for WorkingMemory. + If forced_id is provided, use that as the node id. """ metadata = memory.metadata.model_copy(update={"memory_type": memory_type}).model_dump( exclude_none=True ) metadata["updated_at"] = datetime.now().isoformat() - working_memory = TextualMemoryItem(memory=memory.memory, metadata=metadata) - + node_id = forced_id or str(uuid.uuid4()) + working_memory = TextualMemoryItem(id=node_id, memory=memory.memory, metadata=metadata) # Insert node into graph self.graph_store.add_node(working_memory.id, working_memory.memory, metadata, user_name) + return node_id def _add_to_graph_memory( - self, memory: TextualMemoryItem, memory_type: str, user_name: str | None = None + self, + memory: TextualMemoryItem, + memory_type: str, + user_name: str | None = None, + working_binding: str | None = None, ): """ Generalized method to add memory to a graph-based memory type (e.g., LongTermMemory, UserMemory). - - Parameters: - - memory: memory item to insert - - memory_type: "LongTermMemory" | "UserMemory" - - similarity_threshold: deduplication threshold - - topic_summary_prefix: summary node id prefix if applicable - - enable_summary_link: whether to auto-link to a summary node """ node_id = str(uuid.uuid4()) # Step 2: Add new node to graph + metadata_dict = memory.metadata.model_dump(exclude_none=True) + if working_binding and ("mode:fast" in metadata_dict["tags"]): + prev_bg = metadata_dict.get("background", "") or "" + binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" + if prev_bg: + metadata_dict["background"] = prev_bg + " || " + binding_line + else: + metadata_dict["background"] = binding_line self.graph_store.add_node( node_id, memory.memory, - memory.metadata.model_dump(exclude_none=True), + metadata_dict, user_name=user_name, ) self.reorganizer.add_message( @@ -282,11 +339,11 @@ def _ensure_structure_path( # Step 3: Return this structure node ID as the parent_id return node_id - def remove_and_refresh_memory(self): - self._cleanup_memories_if_needed() - self._refresh_memory_size() + def remove_and_refresh_memory(self, user_name: str | None = None): + self._cleanup_memories_if_needed(user_name=user_name) + self._refresh_memory_size(user_name=user_name) - def _cleanup_memories_if_needed(self) -> None: + def _cleanup_memories_if_needed(self, user_name: str | None = None) -> None: """ Only clean up memories if we're close to or over the limit. This reduces unnecessary database operations. @@ -301,7 +358,7 @@ def _cleanup_memories_if_needed(self) -> None: if current_count >= threshold: try: self.graph_store.remove_oldest_memory( - memory_type=memory_type, keep_latest=limit + memory_type=memory_type, keep_latest=limit, user_name=user_name ) logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}") except Exception: From aba85a5281ac4d1f0d9645921ac9d33d403b7f8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 16:20:34 +0800 Subject: [PATCH 02/12] feat: modify tree and simple_tree, TODO: STILL NOT ALIGN IN SOME FUNCTIONS --- src/memos/memories/textual/simple_tree.py | 26 +---------------------- src/memos/memories/textual/tree.py | 18 ++++++++++------ 2 files changed, 13 insertions(+), 31 deletions(-) diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 52bf62c6d..37ea50462 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -44,6 +44,7 @@ def __init__( """Initialize memory with the given configuration.""" time_start = time.time() self.config: TreeTextMemoryConfig = config + self.mode = self.config.mode self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") @@ -79,20 +80,6 @@ def __init__( logger.info("No internet retriever configured") logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}") - def add( - self, memories: list[TextualMemoryItem | dict[str, Any]], user_name: str | None = None - ) -> list[str]: - """Add memories. - Args: - memories: List of TextualMemoryItem objects or dictionaries to add. - Later: - memory_items = [TextualMemoryItem(**m) if isinstance(m, dict) else m for m in memories] - metadata = extract_metadata(memory_items, self.extractor_llm) - plan = plan_memory_operations(memory_items, metadata, self.graph_store) - execute_plan(memory_items, metadata, plan, self.graph_store) - """ - return self.memory_manager.add(memories, user_name=user_name) - def replace_working_memory( self, memories: list[TextualMemoryItem], user_name: str | None = None ) -> None: @@ -271,17 +258,6 @@ def get(self, memory_id: str) -> TextualMemoryItem: def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: raise NotImplementedError - def get_all(self) -> dict: - """Get all memories. - Returns: - list[TextualMemoryItem]: List of all memories. - """ - all_items = self.graph_store.export_graph() - return all_items - - def delete(self, memory_ids: list[str]) -> None: - raise NotImplementedError - def delete_all(self) -> None: """Delete all memories and their relationships from the graph store.""" try: diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index fccd83fa6..14e6ec334 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -81,12 +81,18 @@ def __init__(self, config: TreeTextMemoryConfig): else: logger.info("No internet retriever configured") - def add(self, memories: list[TextualMemoryItem | dict[str, Any]], **kwargs) -> list[str]: + def add( + self, + memories: list[TextualMemoryItem | dict[str, Any]], + user_name: str | None = None, + **kwargs, + ) -> list[str]: """Add memories. Args: memories: List of TextualMemoryItem objects or dictionaries to add. + user_name: optional user_name """ - return self.memory_manager.add(memories, mode=self.mode) + return self.memory_manager.add(memories, user_name=user_name, mode=self.mode) def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: self.memory_manager.replace_working_memory(memories) @@ -262,21 +268,21 @@ def get(self, memory_id: str) -> TextualMemoryItem: def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: raise NotImplementedError - def get_all(self) -> dict: + def get_all(self, user_name: str | None = None) -> dict: """Get all memories. Returns: list[TextualMemoryItem]: List of all memories. """ - all_items = self.graph_store.export_graph() + all_items = self.graph_store.export_graph(user_name=user_name) return all_items - def delete(self, memory_ids: list[str]) -> None: + def delete(self, memory_ids: list[str], user_name: str | None = None) -> None: """Hard delete: permanently remove nodes and their edges from the graph.""" if not memory_ids: return for mid in memory_ids: try: - self.graph_store.delete_node(mid) + self.graph_store.delete_node(mid, user_name=user_name) except Exception as e: logger.warning(f"TreeTextMemory.delete_hard: failed to delete {mid}: {e}") From fd2c4aafd70c9752174e5bfe5e7c5d85e5e49a19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 16:23:54 +0800 Subject: [PATCH 03/12] feat: modify schedule: add optional user_name in schedule message; modify user-name related graph query in scheduler --- src/memos/mem_scheduler/base_scheduler.py | 4 ++ src/memos/mem_scheduler/general_scheduler.py | 47 ++++++++++++++++--- .../mem_scheduler/schemas/message_schemas.py | 7 +++ 3 files changed, 52 insertions(+), 6 deletions(-) diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 3958ee382..0360396af 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -134,6 +134,7 @@ def initialize_modules( chat_llm: BaseLLM, process_llm: BaseLLM | None = None, db_engine: Engine | None = None, + mem_reader=None, ): if process_llm is None: process_llm = chat_llm @@ -150,6 +151,9 @@ def initialize_modules( self.dispatcher_monitor = SchedulerDispatcherMonitor(config=self.config) self.retriever = SchedulerRetriever(process_llm=self.process_llm, config=self.config) + if mem_reader: + self.mem_reader = mem_reader + if self.enable_parallel_dispatch: self.dispatcher_monitor.initialize(dispatcher=self.dispatcher) self.dispatcher_monitor.start() diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index d84ebb242..03b67261f 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -249,6 +249,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id mem_cube = message.mem_cube content = message.content + user_name = message.user_name # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -272,6 +273,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id=mem_cube_id, mem_cube=mem_cube, text_mem=text_mem, + user_name=user_name, ) logger.info( @@ -296,6 +298,7 @@ def _process_memories_with_reader( mem_cube_id: str, mem_cube: GeneralMemCube, text_mem: TreeTextMemory, + user_name: str, ) -> None: """ Process memories using mem_reader for enhanced memory processing. @@ -329,6 +332,18 @@ def _process_memories_with_reader( logger.warning("No valid memory items found for processing") return + # parse working_binding ids from the *original* memory_items (the raw items created in /add) + # these still carry metadata.background with "[working_binding:...]" so we can know + # which WorkingMemory clones should be cleaned up later. + from memos.memories.textual.tree_text_memory.organize.manager import ( + extract_working_binding_ids, + ) + + bindings_to_delete = extract_working_binding_ids(memory_items) + logger.info( + f"Extracted {len(bindings_to_delete)} working_binding ids to cleanup: {list(bindings_to_delete)}" + ) + # Use mem_reader to process the memories logger.info(f"Processing {len(memory_items)} memories with mem_reader") @@ -352,7 +367,7 @@ def _process_memories_with_reader( # Add the enhanced memories back to the memory system if flattened_memories: - enhanced_mem_ids = text_mem.add(flattened_memories) + enhanced_mem_ids = text_mem.add(flattened_memories, user_name=user_name) logger.info( f"Added {len(enhanced_mem_ids)} enhanced memories: {enhanced_mem_ids}" ) @@ -361,9 +376,26 @@ def _process_memories_with_reader( else: logger.info("mem_reader returned no processed memories") - text_mem.delete(mem_ids) - logger.info("Delete raw mem_ids") - text_mem.memory_manager.remove_and_refresh_memory() + # build full delete list: + # - original raw mem_ids (temporary fast memories) + # - any bound working memories referenced by the enhanced memories + delete_ids = list(mem_ids) + if bindings_to_delete: + delete_ids.extend(list(bindings_to_delete)) + # deduplicate + delete_ids = list(dict.fromkeys(delete_ids)) + if delete_ids: + try: + text_mem.delete(delete_ids, user_name=user_name) + logger.info( + f"Delete raw/working mem_ids: {delete_ids} for user_name: {user_name}" + ) + except Exception as e: + logger.warning(f"Failed to delete some mem_ids {delete_ids}: {e}") + else: + logger.info("No mem_ids to delete (nothing to cleanup)") + + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) logger.info("Remove and Refresh Memories") logger.debug(f"Finished add {user_id} memory: {mem_ids}") @@ -381,6 +413,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id mem_cube = message.mem_cube content = message.content + user_name = message.user_name # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -404,6 +437,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id=mem_cube_id, mem_cube=mem_cube, text_mem=text_mem, + user_name=user_name, ) logger.info( @@ -428,6 +462,7 @@ def _process_memories_with_reorganize( mem_cube_id: str, mem_cube: GeneralMemCube, text_mem: TreeTextMemory, + user_name: str, ) -> None: """ Process memories using mem_reorganize for enhanced memory processing. @@ -454,7 +489,7 @@ def _process_memories_with_reorganize( memory_item = text_mem.get(mem_id) memory_items.append(memory_item) except Exception as e: - logger.warning(f"Failed to get memory {mem_id}: {e}") + logger.warning(f"Failed to get memory {mem_id}: {e}|{traceback.format_exc()}") continue if not memory_items: @@ -463,7 +498,7 @@ def _process_memories_with_reorganize( # Use mem_reader to process the memories logger.info(f"Processing {len(memory_items)} memories with mem_reader") - text_mem.memory_manager.remove_and_refresh_memory() + text_mem.memory_manager.remove_and_refresh_memory(user_name=user_name) logger.info("Remove and Refresh Memories") logger.debug(f"Finished add {user_id} memory: {mem_ids}") diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index bd3155a96..9cdb6823d 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -42,6 +42,10 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): timestamp: datetime = Field( default_factory=get_utc_now, description="submit time for schedule_messages" ) + user_name: str | None = Field( + default=None, + description="user name / display name (optional)", + ) # Pydantic V2 model configuration model_config = ConfigDict( @@ -60,6 +64,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): "mem_cube": "obj of GeneralMemCube", # Added mem_cube example "content": "sample content", # Example message content "timestamp": "2024-07-22T12:00:00Z", # Added timestamp example + "user_name": "Alice", # Added username example } }, ) @@ -81,6 +86,7 @@ def to_dict(self) -> dict: "cube": "Not Applicable", # Custom cube serialization "content": self.content, "timestamp": self.timestamp.isoformat(), + "user_name": self.user_name, } @classmethod @@ -94,6 +100,7 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], timestamp=datetime.fromisoformat(data["timestamp"]), + user_name=data.get("user_name"), ) From bd5806946a246d4920d0533f04fe2817b99a5e79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 16:50:55 +0800 Subject: [PATCH 04/12] feat: finishe server router for async mode --- src/memos/api/routers/server_router.py | 255 +++++++++++++++++++++++-- 1 file changed, 234 insertions(+), 21 deletions(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index f50d3ad75..e73b3345a 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,10 +1,14 @@ +import json import os +import time import traceback from concurrent.futures import ThreadPoolExecutor +from datetime import datetime from typing import TYPE_CHECKING, Any from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse from memos.api.config import APIConfig from memos.api.product_models import ( @@ -32,8 +36,12 @@ from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.general_schemas import ( + ADD_LABEL, + MEM_READ_LABEL, + PREF_ADD_LABEL, SearchMode, ) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.memories.textual.prefer_text_memory.config import ( AdderConfigFactory, ExtractorConfigFactory, @@ -233,6 +241,7 @@ def init_server(): chat_llm=llm, process_llm=mem_reader.llm, db_engine=BaseDBManager.create_default_sqlite_engine(), + mem_reader=mem_reader, ) mem_scheduler.current_mem_cube = naive_mem_cube mem_scheduler.start() @@ -479,6 +488,13 @@ def add_memories(add_req: APIADDRequest): if not target_session_id: target_session_id = "default_session" + # If text memory backend works in async mode, submit tasks to scheduler + try: + sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") + except Exception: + sync_mode = "sync" + logger.info(f"Add sync_mode mode is: {sync_mode}") + def _process_text_mem() -> list[dict[str, str]]: memories_local = mem_reader.get_memory( [add_req.messages], @@ -487,6 +503,7 @@ def _process_text_mem() -> list[dict[str, str]]: "user_id": add_req.user_id, "session_id": target_session_id, }, + mode="fast" if sync_mode == "async" else "fine", ) flattened_local = [mm for m in memories_local for mm in m] logger.info(f"Memory extraction completed for user {add_req.user_id}") @@ -498,6 +515,34 @@ def _process_text_mem() -> list[dict[str, str]]: f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " f"in session {add_req.session_id}: {mem_ids_local}" ) + if sync_mode == "async": + try: + message_item_read = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids_local), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + mem_scheduler.submit_messages(messages=[message_item_read]) + logger.info(f"2105Submit messages!!!!!: {json.dumps(mem_ids_local)}") + except Exception as e: + logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) + else: + message_item_add = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids_local), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + mem_scheduler.submit_messages(messages=[message_item_add]) return [ { "memory": memory.memory, @@ -510,27 +555,46 @@ def _process_text_mem() -> list[dict[str, str]]: def _process_pref_mem() -> list[dict[str, str]]: if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] - pref_memories_local = naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - ) - pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) - logger.info( - f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " - f"in session {add_req.session_id}: {pref_ids_local}" - ) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] + # Follow async behavior similar to core.py: enqueue when async + if sync_mode == "async": + try: + messages_list = [add_req.messages] + message_item_pref = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + mem_scheduler.submit_messages(messages=[message_item_pref]) + logger.info("Submitted preference add to scheduler (async mode)") + except Exception as e: + logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) + return [] + else: + pref_memories_local = naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + ) + pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) + logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] with ThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_process_text_mem) @@ -544,6 +608,155 @@ def _process_pref_mem() -> list[dict[str, str]]: ) +@router.get("/scheduler/status", summary="Get scheduler running task count") +def scheduler_status(): + """ + Return current running tasks count from scheduler dispatcher. + Shape is consistent with /scheduler/wait. + """ + try: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + now_ts = time.time() + + return { + "message": "ok", + "data": { + "running_tasks": running_count, + "timestamp": now_ts, + }, + } + + except Exception as err: + logger.error("Failed to get scheduler status: %s", traceback.format_exc()) + + raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err + + +@router.post("/scheduler/wait", summary="Wait until scheduler is idle") +def scheduler_wait(timeout_seconds: float = 120.0, poll_interval: float = 0.2): + """ + Block until scheduler has no running tasks, or timeout. + We return a consistent structured payload so callers can + tell whether this was a clean flush or a timeout. + + Args: + timeout_seconds: max seconds to wait + poll_interval: seconds between polls + """ + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + elapsed = time.time() - start + + # success -> scheduler is idle + if running_count == 0: + return { + "message": "idle", + "data": { + "running_tasks": 0, + "waited_seconds": round(elapsed, 3), + "timed_out": False, + }, + } + + # timeout check + if elapsed > timeout_seconds: + return { + "message": "timeout", + "data": { + "running_tasks": running_count, + "waited_seconds": round(elapsed, 3), + "timed_out": True, + }, + } + + time.sleep(poll_interval) + + except Exception as err: + logger.error( + "Failed while waiting for scheduler: %s", + traceback.format_exc(), + ) + raise HTTPException( + status_code=500, + detail="Failed while waiting for scheduler", + ) from err + + +@router.get("/scheduler/wait/stream", summary="Stream scheduler progress (SSE)") +def scheduler_wait_stream(timeout_seconds: float = 120.0, poll_interval: float = 0.2): + """ + Stream scheduler progress via Server-Sent Events (SSE). + + Contract: + - We emit periodic heartbeat frames while tasks are still running. + - Each heartbeat frame is JSON, prefixed with "data: ". + - On final frame, we include status = "idle" or "timeout" and timed_out flag, + with the same semantics as /scheduler/wait. + + Example curl: + curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5" + """ + + def event_generator(): + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks() + running_count = len(running) + elapsed = time.time() - start + + # heartbeat frame + heartbeat_payload = { + "running_tasks": running_count, + "elapsed_seconds": round(elapsed, 3), + "status": "running" if running_count > 0 else "idle", + } + yield "data: " + json.dumps(heartbeat_payload, ensure_ascii=False) + "\n\n" + + # scheduler is idle -> final frame + break + if running_count == 0: + final_payload = { + "running_tasks": 0, + "elapsed_seconds": round(elapsed, 3), + "status": "idle", + "timed_out": False, + } + yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n" + break + + # timeout -> final frame + break + if elapsed > timeout_seconds: + final_payload = { + "running_tasks": running_count, + "elapsed_seconds": round(elapsed, 3), + "status": "timeout", + "timed_out": True, + } + yield "data: " + json.dumps(final_payload, ensure_ascii=False) + "\n\n" + break + + time.sleep(poll_interval) + + except Exception as e: + err_payload = { + "status": "error", + "detail": "stream_failed", + "exception": str(e), + } + logger.error( + "Failed streaming scheduler wait: %s: %s", + e, + traceback.format_exc(), + ) + yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n" + + return StreamingResponse(event_generator(), media_type="text/event-stream") + + @router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") def chat_complete(chat_req: APIChatCompleteRequest): """Chat with MemOS for a specific user. Returns complete response (non-streaming).""" From 6d16647d74cb583e3e13cfa7bf99f85942a29fe5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 16:52:13 +0800 Subject: [PATCH 05/12] feat: update graph db --- src/memos/graph_dbs/nebular.py | 16 +++++++--------- src/memos/graph_dbs/neo4j.py | 1 + 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 12b493e58..13fb32502 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -439,6 +439,7 @@ def remove_oldest_memory( Args: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. + user_name(str): optional user_name. """ try: user_name = user_name if user_name else self.config.user_name @@ -685,14 +686,14 @@ def get_node( Returns: dict: Node properties as key-value pairs, or None if not found. """ - user_name = user_name if user_name else self.config.user_name - filter_clause = f'n.user_name = "{user_name}" AND n.id = "{id}"' + filter_clause = f'n.id = "{id}"' return_fields = self._build_return_fields(include_embedding) gql = f""" MATCH (n@Memory) WHERE {filter_clause} RETURN {return_fields} """ + logger.info(f"in get node: {filter_clause}") try: result = self.execute_query(gql) @@ -730,16 +731,13 @@ def get_nodes( """ if not ids: return [] - - user_name = user_name if user_name else self.config.user_name - where_user = f" AND n.user_name = '{user_name}'" # Safe formatting of the ID list id_list = ",".join(f'"{_id}"' for _id in ids) return_fields = self._build_return_fields(include_embedding) query = f""" MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) - WHERE n.id IN [{id_list}] {where_user} + WHERE n.id IN [{id_list}] RETURN {return_fields} """ nodes = [] @@ -1497,10 +1495,10 @@ def _ensure_space_exists(cls, tmp_client, cfg): return try: - res = tmp_client.execute("SHOW GRAPHS;") + res = tmp_client.execute("SHOW GRAPHS") existing = {row.values()[0].as_string() for row in res} if db_name not in existing: - tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;") + tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type") logger.info(f"✅ Graph `{db_name}` created before session binding.") else: logger.debug(f"Graph `{db_name}` already exists.") @@ -1551,7 +1549,7 @@ def _ensure_database_exists(self): """ self.execute_query(create_tag, auto_set_db=False) else: - describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name};" + describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}" desc_result = self.execute_query(describe_query, auto_set_db=False) memory_fields = [] diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py index fd3a1ba22..f3a36a887 100644 --- a/src/memos/graph_dbs/neo4j.py +++ b/src/memos/graph_dbs/neo4j.py @@ -149,6 +149,7 @@ def remove_oldest_memory( Args: memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory'). keep_latest (int): Number of latest WorkingMemory entries to keep. + user_name(str): optional user_name. """ user_name = user_name if user_name else self.config.user_name query = f""" From 9794321fb821814233c8622ceecdb65d21e8daa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 16:53:14 +0800 Subject: [PATCH 06/12] fix: add label in core --- src/memos/mem_os/core.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index ec8a673d7..d253bd220 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -779,16 +779,16 @@ def process_textual_memory(): timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) - - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) + else: + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) def process_preference_memory(): if ( @@ -878,15 +878,16 @@ def process_preference_memory(): timestamp=datetime.utcnow(), ) self.mem_scheduler.submit_messages(messages=[message_item]) - message_item = ScheduleMessageItem( - user_id=target_user_id, - mem_cube_id=mem_cube_id, - mem_cube=mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids), - timestamp=datetime.utcnow(), - ) - self.mem_scheduler.submit_messages(messages=[message_item]) + else: + message_item = ScheduleMessageItem( + user_id=target_user_id, + mem_cube_id=mem_cube_id, + mem_cube=mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) # user doc input if ( From e636df29b95c8e18a12ad7bd12c6c739c5d0c5aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 16:57:52 +0800 Subject: [PATCH 07/12] feat: add tree mode in config --- src/memos/api/config.py | 1 + src/memos/memories/textual/simple_tree.py | 1 + src/memos/memories/textual/tree.py | 2 ++ 3 files changed, 4 insertions(+) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 6de013313..7bb81665e 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -672,6 +672,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None: "LongTermMemory": os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6), "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6), }, + "mode": os.getenv("ASYNC_MODE", "sync"), }, }, "act_mem": {} diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 37ea50462..8d07522cd 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -45,6 +45,7 @@ def __init__( time_start = time.time() self.config: TreeTextMemoryConfig = config self.mode = self.config.mode + logger.info(f"Tree mode is {self.mode}") self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 14e6ec334..472bed219 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -34,6 +34,8 @@ def __init__(self, config: TreeTextMemoryConfig): """Initialize memory with the given configuration.""" # Set mode from class default or override if needed self.mode = config.mode + logger.info(f"Tree mode is {self.mode}") + self.config: TreeTextMemoryConfig = config self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = LLMFactory.from_config( config.extractor_llm From 9b8873ba7191f23041e0a336d61c2aee05ae2c74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 17:06:32 +0800 Subject: [PATCH 08/12] feat: default llm token 8000 --- src/memos/api/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 7bb81665e..92df1ecf8 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -23,7 +23,7 @@ def get_openai_config() -> dict[str, Any]: return { "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-4o-mini"), "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.8")), - "max_tokens": int(os.getenv("MOS_MAX_TOKENS", "1024")), + "max_tokens": int(os.getenv("MOS_MAX_TOKENS", "8000")), "top_p": float(os.getenv("MOS_TOP_P", "0.9")), "top_k": int(os.getenv("MOS_TOP_K", "50")), "remove_think_prefix": True, From e1d2d66f5e72624275502adf3e17f80a9b6794f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 17:13:32 +0800 Subject: [PATCH 09/12] fix: thread --- src/memos/api/routers/server_router.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index e73b3345a..9cab7ae38 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -3,7 +3,6 @@ import time import traceback -from concurrent.futures import ThreadPoolExecutor from datetime import datetime from typing import TYPE_CHECKING, Any @@ -26,6 +25,7 @@ from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory +from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory @@ -379,7 +379,7 @@ def _search_pref(): ) return [_format_memory_item(data) for data in results] - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_search_text) pref_future = executor.submit(_search_pref) text_formatted_memories = text_future.result() @@ -596,7 +596,7 @@ def _process_pref_mem() -> list[dict[str, str]]: for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) ] - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_process_text_mem) pref_future = executor.submit(_process_pref_mem) text_response_data = text_future.result() From cef463f61e8e70ec8055d2ceccd96e800712846e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 17:19:28 +0800 Subject: [PATCH 10/12] feat: search mode in client: fast --- evaluation/scripts/utils/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index e1bdd54e9..9b686a131 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -181,7 +181,7 @@ def search(self, query, user_id, top_k): "mem_cube_id": user_id, "conversation_id": "", "top_k": top_k, - "mode": "mixture", + "mode": "fast", "handle_pref_mem": False, }, ensure_ascii=False, From 4f579a475ed3c23c991a48e9c968bdcb9ce68f2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 17:33:50 +0800 Subject: [PATCH 11/12] tests: fix --- src/memos/graph_dbs/nebular.py | 1 - .../memories/textual/tree_text_memory/organize/manager.py | 3 ++- tests/memories/textual/test_tree.py | 8 ++++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 13fb32502..89b58f417 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -693,7 +693,6 @@ def get_node( WHERE {filter_clause} RETURN {return_fields} """ - logger.info(f"in get node: {filter_clause}") try: result = self.execute_query(gql) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index ddff5d7c1..47cbf4ed1 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -238,7 +238,8 @@ def _add_to_graph_memory( node_id = str(uuid.uuid4()) # Step 2: Add new node to graph metadata_dict = memory.metadata.model_dump(exclude_none=True) - if working_binding and ("mode:fast" in metadata_dict["tags"]): + tags = metadata_dict.get("tags") or [] + if working_binding and ("mode:fast" in tags): prev_bg = metadata_dict.get("background", "") or "" binding_line = f"[working_binding:{working_binding}] direct built from raw inputs" if prev_bg: diff --git a/tests/memories/textual/test_tree.py b/tests/memories/textual/test_tree.py index 772a79d78..a72709ec5 100644 --- a/tests/memories/textual/test_tree.py +++ b/tests/memories/textual/test_tree.py @@ -66,7 +66,9 @@ def test_add_calls_manager(mock_tree_text_memory): metadata=TreeNodeTextualMemoryMetadata(updated_at=None), ) mock_tree_text_memory.add([mock_item]) - mock_tree_text_memory.memory_manager.add.assert_called_once_with([mock_item], mode="sync") + mock_tree_text_memory.memory_manager.add.assert_called_once_with( + [mock_item], user_name=None, mode="sync" + ) def test_get_working_memory_sorted(mock_tree_text_memory): @@ -161,4 +163,6 @@ def test_add_returns_ids(mock_tree_text_memory): result = mock_tree_text_memory.add(mock_items) assert result == dummy_ids - mock_tree_text_memory.memory_manager.add.assert_called_once_with(mock_items, mode="sync") + mock_tree_text_memory.memory_manager.add.assert_called_once_with( + mock_items, user_name=None, mode="sync" + ) From ef879e3c8c132af00d09fd9d6b1d6f8cdacf292c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B8=AD=E9=98=B3=E9=98=B3?= Date: Wed, 29 Oct 2025 21:13:57 +0800 Subject: [PATCH 12/12] fix: add some log for memory_size in manager --- .../memories/textual/tree_text_memory/organize/manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 47cbf4ed1..01ccc382b 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -351,9 +351,10 @@ def _cleanup_memories_if_needed(self, user_name: str | None = None) -> None: """ cleanup_threshold = 0.8 # Clean up when 80% full + logger.info(f"self.memory_size: {self.memory_size}") for memory_type, limit in self.memory_size.items(): current_count = self.current_memory_size.get(memory_type, 0) - threshold = int(limit * cleanup_threshold) + threshold = int(int(limit) * cleanup_threshold) # Only clean up if we're at or above the threshold if current_count >= threshold: