diff --git a/.gitignore b/.gitignore
index ab7848f74..c972d59de 100644
--- a/.gitignore
+++ b/.gitignore
@@ -234,6 +234,6 @@ apps/openwork-memos-integration/apps/desktop/public/assets/usecases/
# Outputs and Evaluation Results
outputs
-evaluation/data/temporal_locomo
+evaluation/data/
test_add_pipeline.py
test_file_pipeline.py
diff --git a/Makefile b/Makefile
index 788504a73..eb22e241d 100644
--- a/Makefile
+++ b/Makefile
@@ -36,7 +36,7 @@ pre_commit:
poetry run pre-commit run -a
serve:
- poetry run uvicorn memos.api.start_api:app
+ poetry run uvicorn memos.api.server_api:app
openapi:
poetry run memos export_openapi --output docs/openapi.json
diff --git a/README.md b/README.md
index a7b05d683..1c7d5bd93 100644
--- a/README.md
+++ b/README.md
@@ -75,7 +75,7 @@
- [**72% lower token usage**](https://x.com/MemOS_dev/status/2020854044583924111) — intelligent memory retrieval instead of loading full chat history
- [**Multi-agent memory sharing**](https://x.com/MemOS_dev/status/2020538135487062094) — multi-instance agents share memory via same user_id, automatic context handoff
-Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/)
+Get your API key: [MemOS Dashboard](https://memos-dashboard.openmem.net/cn/login/)
Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)
### 🧠 Local Plugin — 100% On-Device Memory
@@ -84,7 +84,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **Hybrid search + task & skill evolution** — FTS5 + vector search, auto task summarization, reusable skills that self-upgrade
- **Multi-agent collaboration + Memory Viewer** — memory isolation, skill sharing, full web dashboard with 7 management pages
- 🏠 [Homepage](https://memos-claw.openmem.net) ·
+ 🏠 [Homepage](https://memos-claw.openmem.net) ·
📖 [Documentation](https://memos-claw.openmem.net/docs/index.html) · 📦 [NPM](https://www.npmjs.com/package/@memtensor/memos-local-openclaw-plugin)
## 🧠 MemOS: Memory Operating System for AI Agents
@@ -104,10 +104,10 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
### News
- **2026-03-08** · 📦 **MemOS OpenClaw Plugin — Cloud & Local**
+- **2026-03-08** · 📦 **MemOS OpenClaw Plugin — Cloud & Local**
Official OpenClaw memory plugins launched. **Cloud Plugin**: hosted memory service with 72% lower token usage and multi-agent memory sharing ([MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/MemOS-Cloud-OpenClaw-Plugin)). **Local Plugin** (`v1.0.0`): 100% on-device memory with persistent SQLite, hybrid search (FTS5 + vector), task summarization & skill evolution, multi-agent collaboration, and a full Memory Viewer dashboard.
- **2025-12-24** · 🌟 **MemOS v2.0: Stardust (星尘) Release**
+- **2025-12-24** · 🌟 **MemOS v2.0: Stardust (星尘) Release**
Comprehensive KB (doc/URL parsing + cross-project sharing), memory feedback & precise deletion, multi-modal memory (images/charts), tool memory for agent planning, Redis Streams scheduling + DB optimizations, streaming/non-streaming chat, MCP upgrade, and lightweight quick/full deployment.
✨ New Features
@@ -155,7 +155,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- **2025-08-07** · 🎉 **MemOS v1.0.0 (MemCube) Release**
- First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, NebulaGraph support, improved search capabilities, and the official Playground launch.
+ First MemCube release with a word-game demo, LongMemEval evaluation, BochaAISearchRetriever integration, improved search capabilities, and the official Playground launch.
✨ New Features
@@ -176,7 +176,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
**Plaintext Memory**
- Integrated internet search with Bocha.
- - Added support for Nebula database.
+ - Expanded graph database support.
- Added contextual understanding for the tree-structured plaintext memory search interface.
@@ -188,7 +188,7 @@ Full tutorial → [MemOS-Cloud-OpenClaw-Plugin](https://github.com/MemTensor/Mem
- Fixed the concat_cache method.
**Plaintext Memory**
- - Fixed Nebula search-related issues.
+ - Fixed graph search-related issues.
diff --git a/docker/.env.example b/docker/.env.example
index 3674cd69b..08f9cf1db 100644
--- a/docker/.env.example
+++ b/docker/.env.example
@@ -127,7 +127,7 @@ MEMSCHEDULER_USE_REDIS_QUEUE=false
## Graph / vector stores
# Neo4j database selection mode
-NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | nebular | polardb
+NEO4J_BACKEND=neo4j-community # neo4j-community | neo4j | polardb | postgres
# Neo4j database url
NEO4J_URI=bolt://localhost:7687 # required when backend=neo4j*
# Neo4j database user
diff --git a/examples/mem_agent/deepsearch_example.py b/examples/mem_agent/deepsearch_example.py
index 6dbe202c2..d14b6d687 100644
--- a/examples/mem_agent/deepsearch_example.py
+++ b/examples/mem_agent/deepsearch_example.py
@@ -47,7 +47,7 @@ def build_minimal_components():
# Build component configurations using APIConfig methods (like config_builders.py)
- # Graph DB configuration - using APIConfig.get_nebular_config()
+ # Graph DB configuration - using APIConfig graph DB helpers
graph_db_backend = os.getenv("NEO4J_BACKEND", "polardb").lower()
graph_db_backend_map = {
"polardb": APIConfig.get_polardb_config(),
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 87f1efd8e..2e8bae57c 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -741,21 +741,6 @@ def get_neo4j_shared_config(user_id: str | None = None) -> dict[str, Any]:
"embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
}
- @staticmethod
- def get_nebular_config(user_id: str | None = None) -> dict[str, Any]:
- """Get Nebular configuration."""
- return {
- "uri": json.loads(os.getenv("NEBULAR_HOSTS", '["localhost"]')),
- "user": os.getenv("NEBULAR_USER", "root"),
- "password": os.getenv("NEBULAR_PASSWORD", "xxxxxx"),
- "space": os.getenv("NEBULAR_SPACE", "shared-tree-textual-memory"),
- "user_name": f"memos{user_id.replace('-', '')}",
- "use_multi_db": False,
- "auto_create": True,
- "embedding_dimension": int(os.getenv("EMBEDDING_DIMENSION", 3072)),
- }
-
- @staticmethod
def get_milvus_config():
return {
"collection_name": [
@@ -1103,7 +1088,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id)
neo4j_config = APIConfig.get_neo4j_config(user_id)
- nebular_config = APIConfig.get_nebular_config(user_id)
polardb_config = APIConfig.get_polardb_config(user_id)
internet_config = (
APIConfig.get_internet_config()
@@ -1114,7 +1098,6 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
- "nebular": nebular_config,
"polardb": polardb_config,
"postgres": postgres_config,
}
@@ -1144,9 +1127,9 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
"reorganize": os.getenv("MOS_ENABLE_REORGANIZE", "false").lower()
== "true",
"memory_size": {
- "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)),
- "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)),
- "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)),
+ "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)),
+ "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)),
+ "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)),
},
"search_strategy": {
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
@@ -1169,7 +1152,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
}
)
else:
- raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")
+ raise ValueError(f"Invalid graph DB backend: {graph_db_backend}")
default_mem_cube = GeneralMemCube(default_cube_config)
return default_config, default_mem_cube
@@ -1188,13 +1171,11 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
openai_config = APIConfig.get_openai_config()
neo4j_community_config = APIConfig.get_neo4j_community_config(user_id="default")
neo4j_config = APIConfig.get_neo4j_config(user_id="default")
- nebular_config = APIConfig.get_nebular_config(user_id="default")
polardb_config = APIConfig.get_polardb_config(user_id="default")
postgres_config = APIConfig.get_postgres_config(user_id="default")
graph_db_backend_map = {
"neo4j-community": neo4j_community_config,
"neo4j": neo4j_config,
- "nebular": nebular_config,
"polardb": polardb_config,
"postgres": postgres_config,
}
@@ -1227,9 +1208,9 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
== "true",
"internet_retriever": internet_config,
"memory_size": {
- "WorkingMemory": int(os.getenv("NEBULAR_WORKING_MEMORY", 20)),
- "LongTermMemory": int(os.getenv("NEBULAR_LONGTERM_MEMORY", 1e6)),
- "UserMemory": int(os.getenv("NEBULAR_USER_MEMORY", 1e6)),
+ "WorkingMemory": int(os.getenv("MOS_WORKING_MEMORY", 20)),
+ "LongTermMemory": int(os.getenv("MOS_LONGTERM_MEMORY", 1e6)),
+ "UserMemory": int(os.getenv("MOS_USER_MEMORY", 1e6)),
},
"search_strategy": {
"fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
@@ -1253,4 +1234,4 @@ def get_default_cube_config() -> "GeneralMemCubeConfig | None":
}
)
else:
- raise ValueError(f"Invalid Neo4j backend: {graph_db_backend}")
+ raise ValueError(f"Invalid graph DB backend: {graph_db_backend}")
diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py
index e071eacb3..eb982fd06 100644
--- a/src/memos/api/handlers/base_handler.py
+++ b/src/memos/api/handlers/base_handler.py
@@ -36,7 +36,6 @@ def __init__(
vector_db: Any | None = None,
internet_retriever: Any | None = None,
memory_manager: Any | None = None,
- mos_server: Any | None = None,
feedback_server: Any | None = None,
**kwargs,
):
@@ -54,7 +53,6 @@ def __init__(
vector_db: Vector database instance
internet_retriever: Internet retriever instance
memory_manager: Memory manager instance
- mos_server: MOS server instance
**kwargs: Additional dependencies
"""
self.llm = llm
@@ -68,7 +66,6 @@ def __init__(
self.vector_db = vector_db
self.internet_retriever = internet_retriever
self.memory_manager = memory_manager
- self.mos_server = mos_server
self.feedback_server = feedback_server
# Store any additional dependencies
@@ -158,11 +155,6 @@ def vector_db(self):
"""Get vector database instance."""
return self.deps.vector_db
- @property
- def mos_server(self):
- """Get MOS server instance."""
- return self.deps.mos_server
-
@property
def deepsearch_agent(self):
"""Get deepsearch agent instance."""
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
index 7894ff7dc..a01fffef8 100644
--- a/src/memos/api/handlers/component_init.py
+++ b/src/memos/api/handlers/component_init.py
@@ -28,7 +28,6 @@
from memos.log import get_logger
from memos.mem_cube.navie import NaiveMemCube
from memos.mem_feedback.simple_feedback import SimpleMemFeedback
-from memos.mem_os.product_server import MOSServer
from memos.mem_reader.factory import MemReaderFactory
from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
@@ -211,15 +210,6 @@ def init_server() -> dict[str, Any]:
logger.debug("Text memory initialized")
- # Initialize MOS Server
- mos_server = MOSServer(
- mem_reader=mem_reader,
- llm=llm,
- online_bot=False,
- )
-
- logger.debug("MOS server initialized")
-
# Create MemCube with pre-initialized memory instances
naive_mem_cube = NaiveMemCube(
text_mem=text_mem,
@@ -304,7 +294,6 @@ def init_server() -> dict[str, Any]:
"internet_retriever": internet_retriever,
"memory_manager": memory_manager,
"default_cube_config": default_cube_config,
- "mos_server": mos_server,
"mem_scheduler": mem_scheduler,
"naive_mem_cube": naive_mem_cube,
"searcher": searcher,
diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py
index 5655bf1e5..d29429fc9 100644
--- a/src/memos/api/handlers/config_builders.py
+++ b/src/memos/api/handlers/config_builders.py
@@ -39,13 +39,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
graph_db_backend_map = {
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
- "nebular": APIConfig.get_nebular_config(user_id=user_id),
"polardb": APIConfig.get_polardb_config(user_id=user_id),
"postgres": APIConfig.get_postgres_config(user_id=user_id),
}
# Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
- graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
+ graph_db_backend = os.getenv(
+ "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+ ).lower()
return GraphDBConfigFactory.model_validate(
{
"backend": graph_db_backend,
diff --git a/src/memos/api/product_api.py b/src/memos/api/product_api.py
deleted file mode 100644
index ec5cccae1..000000000
--- a/src/memos/api/product_api.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import logging
-
-from fastapi import FastAPI
-
-from memos.api.exceptions import APIExceptionHandler
-from memos.api.middleware.request_context import RequestContextMiddleware
-from memos.api.routers.product_router import router as product_router
-
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger(__name__)
-
-app = FastAPI(
- title="MemOS Product REST APIs",
- description="A REST API for managing multiple users with MemOS Product.",
- version="1.0.1",
-)
-
-app.add_middleware(RequestContextMiddleware, source="product_api")
-# Include routers
-app.include_router(product_router)
-
-# Exception handlers
-app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler)
-app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler)
-
-
-if __name__ == "__main__":
- import argparse
-
- import uvicorn
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--port", type=int, default=8001)
- parser.add_argument("--workers", type=int, default=1)
- args = parser.parse_args()
- uvicorn.run("memos.api.product_api:app", host="0.0.0.0", port=args.port, workers=args.workers)
diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py
deleted file mode 100644
index 609d61124..000000000
--- a/src/memos/api/routers/product_router.py
+++ /dev/null
@@ -1,477 +0,0 @@
-import json
-import time
-import traceback
-
-from fastapi import APIRouter, HTTPException
-from fastapi.responses import StreamingResponse
-
-from memos.api.config import APIConfig
-from memos.api.product_models import (
- BaseResponse,
- ChatCompleteRequest,
- ChatRequest,
- GetMemoryPlaygroundRequest,
- MemoryCreateRequest,
- MemoryResponse,
- SearchRequest,
- SearchResponse,
- SimpleResponse,
- SuggestionRequest,
- SuggestionResponse,
- UserRegisterRequest,
- UserRegisterResponse,
-)
-from memos.configs.mem_os import MOSConfig
-from memos.log import get_logger
-from memos.mem_os.product import MOSProduct
-from memos.memos_tools.notification_service import get_error_bot_function, get_online_bot_function
-
-
-logger = get_logger(__name__)
-
-router = APIRouter(prefix="/product", tags=["Product API"])
-
-# Initialize MOSProduct instance with lazy initialization
-MOS_PRODUCT_INSTANCE = None
-
-
-def get_mos_product_instance():
- """Get or create MOSProduct instance."""
- global MOS_PRODUCT_INSTANCE
- if MOS_PRODUCT_INSTANCE is None:
- default_config = APIConfig.get_product_default_config()
- logger.info(f"*********init_default_mos_config********* {default_config}")
- from memos.configs.mem_os import MOSConfig
-
- mos_config = MOSConfig(**default_config)
-
- # Get default cube config from APIConfig (may be None if disabled)
- default_cube_config = APIConfig.get_default_cube_config()
- logger.info(f"*********initdefault_cube_config******** {default_cube_config}")
-
- # Get DingDing bot functions
- dingding_enabled = APIConfig.is_dingding_bot_enabled()
- online_bot = get_online_bot_function() if dingding_enabled else None
- error_bot = get_error_bot_function() if dingding_enabled else None
-
- MOS_PRODUCT_INSTANCE = MOSProduct(
- default_config=mos_config,
- default_cube_config=default_cube_config,
- online_bot=online_bot,
- error_bot=error_bot,
- )
- logger.info("MOSProduct instance created successfully with inheritance architecture")
- return MOS_PRODUCT_INSTANCE
-
-
-get_mos_product_instance()
-
-
-@router.post("/configure", summary="Configure MOSProduct", response_model=SimpleResponse)
-def set_config(config):
- """Set MOSProduct configuration."""
- global MOS_PRODUCT_INSTANCE
- MOS_PRODUCT_INSTANCE = MOSProduct(default_config=config)
- return SimpleResponse(message="Configuration set successfully")
-
-
-@router.post("/users/register", summary="Register a new user", response_model=UserRegisterResponse)
-def register_user(user_req: UserRegisterRequest):
- """Register a new user with configuration and default cube."""
- try:
- # Get configuration for the user
- time_start_register = time.time()
- user_config, default_mem_cube = APIConfig.create_user_config(
- user_name=user_req.user_id, user_id=user_req.user_id
- )
- logger.info(f"user_config: {user_config.model_dump(mode='json')}")
- logger.info(f"default_mem_cube: {default_mem_cube.config.model_dump(mode='json')}")
- logger.info(
- f"time register api : create user config time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
- )
- mos_product = get_mos_product_instance()
-
- # Register user with default config and mem cube
- result = mos_product.user_register(
- user_id=user_req.user_id,
- user_name=user_req.user_name,
- interests=user_req.interests,
- config=user_config,
- default_mem_cube=default_mem_cube,
- mem_cube_id=user_req.mem_cube_id,
- )
- logger.info(
- f"time register api : register time user_id: {user_req.user_id} time is: {time.time() - time_start_register}"
- )
- if result["status"] == "success":
- return UserRegisterResponse(
- message="User registered successfully",
- data={"user_id": result["user_id"], "mem_cube_id": result["default_cube_id"]},
- )
- else:
- raise HTTPException(status_code=400, detail=result["message"])
-
- except Exception as err:
- logger.error(f"Failed to register user: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get(
- "/suggestions/{user_id}", summary="Get suggestion queries", response_model=SuggestionResponse
-)
-def get_suggestion_queries(user_id: str):
- """Get suggestion queries for a specific user."""
- try:
- mos_product = get_mos_product_instance()
- suggestions = mos_product.get_suggestion_query(user_id)
- return SuggestionResponse(
- message="Suggestions retrieved successfully", data={"query": suggestions}
- )
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get suggestions: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post(
- "/suggestions",
- summary="Get suggestion queries with language",
- response_model=SuggestionResponse,
-)
-def get_suggestion_queries_post(suggestion_req: SuggestionRequest):
- """Get suggestion queries for a specific user with language preference."""
- try:
- mos_product = get_mos_product_instance()
- suggestions = mos_product.get_suggestion_query(
- user_id=suggestion_req.user_id,
- language=suggestion_req.language,
- message=suggestion_req.message,
- )
- return SuggestionResponse(
- message="Suggestions retrieved successfully", data={"query": suggestions}
- )
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get suggestions: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse)
-def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
- """Get all memories for a specific user."""
- try:
- mos_product = get_mos_product_instance()
- if memory_req.search_query:
- result = mos_product.get_subgraph(
- user_id=memory_req.user_id,
- query=memory_req.search_query,
- mem_cube_ids=memory_req.mem_cube_ids,
- )
- return MemoryResponse(message="Memories retrieved successfully", data=result)
- else:
- result = mos_product.get_all(
- user_id=memory_req.user_id,
- memory_type=memory_req.memory_type,
- mem_cube_ids=memory_req.mem_cube_ids,
- )
- return MemoryResponse(message="Memories retrieved successfully", data=result)
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get memories: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/add", summary="add a new memory", response_model=SimpleResponse)
-def create_memory(memory_req: MemoryCreateRequest):
- """Create a new memory for a specific user."""
- logger.info("DIAGNOSTIC: /product/add endpoint called. This confirms the new code is deployed.")
- # Initialize status_tracker outside try block to avoid NameError in except blocks
- status_tracker = None
-
- try:
- time_start_add = time.time()
- mos_product = get_mos_product_instance()
-
- # Track task if task_id is provided
- item_id: str | None = None
- if (
- memory_req.task_id
- and hasattr(mos_product, "mem_scheduler")
- and mos_product.mem_scheduler
- ):
- from uuid import uuid4
-
- from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
-
- item_id = str(uuid4()) # Generate a unique item_id for this submission
-
- # Get Redis client from scheduler
- if (
- hasattr(mos_product.mem_scheduler, "redis_client")
- and mos_product.mem_scheduler.redis_client
- ):
- status_tracker = TaskStatusTracker(mos_product.mem_scheduler.redis_client)
- # Submit task with "product_add" type
- status_tracker.task_submitted(
- task_id=item_id, # Use generated item_id for internal tracking
- user_id=memory_req.user_id,
- task_type="product_add",
- mem_cube_id=memory_req.mem_cube_id or memory_req.user_id,
- business_task_id=memory_req.task_id, # Use memory_req.task_id as business_task_id
- )
- status_tracker.task_started(item_id, memory_req.user_id) # Use item_id here
-
- # Execute the add operation
- mos_product.add(
- user_id=memory_req.user_id,
- memory_content=memory_req.memory_content,
- messages=memory_req.messages,
- doc_path=memory_req.doc_path,
- mem_cube_id=memory_req.mem_cube_id,
- source=memory_req.source,
- user_profile=memory_req.user_profile,
- session_id=memory_req.session_id,
- task_id=memory_req.task_id,
- )
-
- # Mark task as completed
- if status_tracker and item_id:
- status_tracker.task_completed(item_id, memory_req.user_id)
-
- logger.info(
- f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}"
- )
- return SimpleResponse(message="Memory created successfully")
-
- except ValueError as err:
- # Mark task as failed if tracking
- if status_tracker and item_id:
- status_tracker.task_failed(item_id, memory_req.user_id, str(err))
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- # Mark task as failed if tracking
- if status_tracker and item_id:
- status_tracker.task_failed(item_id, memory_req.user_id, str(err))
- logger.error(f"Failed to create memory: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/search", summary="Search memories", response_model=SearchResponse)
-def search_memories(search_req: SearchRequest):
- """Search memories for a specific user."""
- try:
- time_start_search = time.time()
- mos_product = get_mos_product_instance()
- result = mos_product.search(
- query=search_req.query,
- user_id=search_req.user_id,
- install_cube_ids=[search_req.mem_cube_id] if search_req.mem_cube_id else None,
- top_k=search_req.top_k,
- session_id=search_req.session_id,
- )
- logger.info(
- f"time search api : add time user_id: {search_req.user_id} time is: {time.time() - time_start_search}"
- )
- return SearchResponse(message="Search completed successfully", data=result)
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to search memories: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/chat", summary="Chat with MemOS")
-def chat(chat_req: ChatRequest):
- """Chat with MemOS for a specific user. Returns SSE stream."""
- try:
- mos_product = get_mos_product_instance()
-
- def generate_chat_response():
- """Generate chat response as SSE stream."""
- try:
- # Directly yield from the generator without async wrapper
- yield from mos_product.chat_with_references(
- query=chat_req.query,
- user_id=chat_req.user_id,
- cube_id=chat_req.mem_cube_id,
- history=chat_req.history,
- internet_search=chat_req.internet_search,
- moscube=chat_req.moscube,
- session_id=chat_req.session_id,
- )
-
- except Exception as e:
- logger.error(f"Error in chat stream: {e}")
- error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
- yield error_data
-
- return StreamingResponse(
- generate_chat_response(),
- media_type="text/event-stream",
- headers={
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- "Content-Type": "text/event-stream",
- "Access-Control-Allow-Origin": "*",
- "Access-Control-Allow-Headers": "*",
- "Access-Control-Allow-Methods": "*",
- },
- )
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to start chat: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)")
-def chat_complete(chat_req: ChatCompleteRequest):
- """Chat with MemOS for a specific user. Returns complete response (non-streaming)."""
- try:
- mos_product = get_mos_product_instance()
-
- # Collect all responses from the generator
- content, references = mos_product.chat(
- query=chat_req.query,
- user_id=chat_req.user_id,
- cube_id=chat_req.mem_cube_id,
- history=chat_req.history,
- internet_search=chat_req.internet_search,
- moscube=chat_req.moscube,
- base_prompt=chat_req.base_prompt or chat_req.system_prompt,
- # will deprecate base_prompt in the future
- top_k=chat_req.top_k,
- threshold=chat_req.threshold,
- session_id=chat_req.session_id,
- )
-
- # Return the complete response
- return {
- "message": "Chat completed successfully",
- "data": {"response": content, "references": references},
- }
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to start chat: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get("/users", summary="List all users", response_model=BaseResponse[list])
-def list_users():
- """List all registered users."""
- try:
- mos_product = get_mos_product_instance()
- users = mos_product.list_users()
- return BaseResponse(message="Users retrieved successfully", data=users)
- except Exception as err:
- logger.error(f"Failed to list users: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get("/users/{user_id}", summary="Get user info", response_model=BaseResponse[dict])
-async def get_user_info(user_id: str):
- """Get user information including accessible cubes."""
- try:
- mos_product = get_mos_product_instance()
- user_info = mos_product.get_user_info(user_id)
- return BaseResponse(message="User info retrieved successfully", data=user_info)
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get user info: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get(
- "/configure/{user_id}", summary="Get MOSProduct configuration", response_model=SimpleResponse
-)
-def get_config(user_id: str):
- """Get MOSProduct configuration."""
- global MOS_PRODUCT_INSTANCE
- config = MOS_PRODUCT_INSTANCE.default_config
- return SimpleResponse(message="Configuration retrieved successfully", data=config)
-
-
-@router.get(
- "/users/{user_id}/config", summary="Get user configuration", response_model=BaseResponse[dict]
-)
-def get_user_config(user_id: str):
- """Get user-specific configuration."""
- try:
- mos_product = get_mos_product_instance()
- config = mos_product.get_user_config(user_id)
- if config:
- return BaseResponse(
- message="User configuration retrieved successfully",
- data=config.model_dump(mode="json"),
- )
- else:
- raise HTTPException(
- status_code=404, detail=f"Configuration not found for user {user_id}"
- )
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to get user config: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.put(
- "/users/{user_id}/config", summary="Update user configuration", response_model=SimpleResponse
-)
-def update_user_config(user_id: str, config_data: dict):
- """Update user-specific configuration."""
- try:
- mos_product = get_mos_product_instance()
-
- # Create MOSConfig from the provided data
- config = MOSConfig(**config_data)
-
- # Update the configuration
- success = mos_product.update_user_config(user_id, config)
- if success:
- return SimpleResponse(message="User configuration updated successfully")
- else:
- raise HTTPException(status_code=500, detail="Failed to update user configuration")
-
- except ValueError as err:
- raise HTTPException(status_code=400, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to update user config: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get(
- "/instances/status", summary="Get user configuration status", response_model=BaseResponse[dict]
-)
-def get_instance_status():
- """Get information about active user configurations in memory."""
- try:
- mos_product = get_mos_product_instance()
- status_info = mos_product.get_user_instance_info()
- return BaseResponse(
- message="User configuration status retrieved successfully", data=status_info
- )
- except Exception as err:
- logger.error(f"Failed to get user configuration status: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
-
-
-@router.get("/instances/count", summary="Get active user count", response_model=BaseResponse[int])
-def get_active_user_count():
- """Get the number of active user configurations in memory."""
- try:
- mos_product = get_mos_product_instance()
- count = mos_product.get_active_user_count()
- return BaseResponse(message="Active user count retrieved successfully", data=count)
- except Exception as err:
- logger.error(f"Failed to get active user count: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
diff --git a/src/memos/api/start_api.py b/src/memos/api/start_api.py
deleted file mode 100644
index 24a36f017..000000000
--- a/src/memos/api/start_api.py
+++ /dev/null
@@ -1,433 +0,0 @@
-import logging
-import os
-
-from typing import Any, Generic, TypeVar
-
-from dotenv import load_dotenv
-from fastapi import FastAPI
-from fastapi.requests import Request
-from fastapi.responses import JSONResponse, RedirectResponse
-from pydantic import BaseModel, Field
-
-from memos.api.middleware.request_context import RequestContextMiddleware
-from memos.configs.mem_os import MOSConfig
-from memos.mem_os.main import MOS
-from memos.mem_user.user_manager import UserManager, UserRole
-
-
-# Configure logging
-logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
-logger = logging.getLogger(__name__)
-
-# Load environment variables
-load_dotenv(override=True)
-
-T = TypeVar("T")
-
-# Default configuration
-DEFAULT_CONFIG = {
- "user_id": os.getenv("MOS_USER_ID", "default_user"),
- "session_id": os.getenv("MOS_SESSION_ID", "default_session"),
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "top_k": int(os.getenv("MOS_TOP_K", "5")),
- "chat_model": {
- "backend": os.getenv("MOS_CHAT_MODEL_PROVIDER", "openai"),
- "config": {
- "model_name_or_path": os.getenv("MOS_CHAT_MODEL", "gpt-3.5-turbo"),
- "api_key": os.getenv("OPENAI_API_KEY", "apikey"),
- "temperature": float(os.getenv("MOS_CHAT_TEMPERATURE", "0.7")),
- "api_base": os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
- },
- },
-}
-
-# Initialize MOS instance with lazy initialization
-MOS_INSTANCE = None
-
-
-def get_mos_instance():
- """Get or create MOS instance with default user creation."""
- global MOS_INSTANCE
- if MOS_INSTANCE is None:
- # Create a temporary MOS instance to access user manager
- temp_config = MOSConfig(**DEFAULT_CONFIG)
- temp_mos = MOS.__new__(MOS)
- temp_mos.config = temp_config
- temp_mos.user_id = temp_config.user_id
- temp_mos.session_id = temp_config.session_id
- temp_mos.mem_cubes = {}
- temp_mos.chat_llm = None # Will be initialized later
- temp_mos.user_manager = UserManager()
-
- # Create default user if it doesn't exist
- if not temp_mos.user_manager.validate_user(temp_config.user_id):
- temp_mos.user_manager.create_user(
- user_name=temp_config.user_id, role=UserRole.USER, user_id=temp_config.user_id
- )
- logger.info(f"Created default user: {temp_config.user_id}")
-
- # Now create the actual MOS instance
- MOS_INSTANCE = MOS(config=temp_config)
-
- return MOS_INSTANCE
-
-
-app = FastAPI(
- title="MemOS REST APIs",
- description="A REST API for managing and searching memories using MemOS.",
- version="1.0.0",
-)
-
-app.add_middleware(RequestContextMiddleware)
-
-
-class BaseRequest(BaseModel):
- """Base model for all requests."""
-
- user_id: str | None = Field(
- None, description="User ID for the request", json_schema_extra={"example": "user123"}
- )
-
-
-class BaseResponse(BaseModel, Generic[T]):
- """Base model for all responses."""
-
- code: int = Field(200, description="Response status code", json_schema_extra={"example": 200})
- message: str = Field(
- ..., description="Response message", json_schema_extra={"example": "Operation successful"}
- )
- data: T | None = Field(None, description="Response data")
-
-
-class Message(BaseModel):
- role: str = Field(
- ...,
- description="Role of the message (user or assistant).",
- json_schema_extra={"example": "user"},
- )
- content: str = Field(
- ...,
- description="Message content.",
- json_schema_extra={"example": "Hello, how can I help you?"},
- )
-
-
-class MemoryCreate(BaseRequest):
- messages: list[Message] | None = Field(
- None,
- description="List of messages to store.",
- json_schema_extra={"example": [{"role": "user", "content": "Hello"}]},
- )
- mem_cube_id: str | None = Field(
- None, description="ID of the memory cube", json_schema_extra={"example": "cube123"}
- )
- memory_content: str | None = Field(
- None,
- description="Content to store as memory",
- json_schema_extra={"example": "This is a memory content"},
- )
- doc_path: str | None = Field(
- None,
- description="Path to document to store",
- json_schema_extra={"example": "/path/to/document.txt"},
- )
-
-
-class SearchRequest(BaseRequest):
- query: str = Field(
- ...,
- description="Search query.",
- json_schema_extra={"example": "How to implement a feature?"},
- )
- install_cube_ids: list[str] | None = Field(
- None,
- description="List of cube IDs to search in",
- json_schema_extra={"example": ["cube123", "cube456"]},
- )
-
-
-class MemCubeRegister(BaseRequest):
- mem_cube_name_or_path: str = Field(
- ...,
- description="Name or path of the MemCube to register.",
- json_schema_extra={"example": "/path/to/cube"},
- )
- mem_cube_id: str | None = Field(
- None, description="ID for the MemCube", json_schema_extra={"example": "cube123"}
- )
-
-
-class ChatRequest(BaseRequest):
- query: str = Field(
- ...,
- description="Chat query message.",
- json_schema_extra={"example": "What is the latest update?"},
- )
-
-
-class UserCreate(BaseRequest):
- user_name: str | None = Field(
- None, description="Name of the user", json_schema_extra={"example": "john_doe"}
- )
- role: str = Field("user", description="Role of the user", json_schema_extra={"example": "user"})
- user_id: str = Field(..., description="User ID", json_schema_extra={"example": "user123"})
-
-
-class CubeShare(BaseRequest):
- target_user_id: str = Field(
- ..., description="Target user ID to share with", json_schema_extra={"example": "user456"}
- )
-
-
-class SimpleResponse(BaseResponse[None]):
- """Simple response model for operations without data return."""
-
-
-class ConfigResponse(BaseResponse[None]):
- """Response model for configuration endpoint."""
-
-
-class MemoryResponse(BaseResponse[dict]):
- """Response model for memory operations."""
-
-
-class SearchResponse(BaseResponse[dict]):
- """Response model for search operations."""
-
-
-class ChatResponse(BaseResponse[str]):
- """Response model for chat operations."""
-
-
-class UserResponse(BaseResponse[dict]):
- """Response model for user operations."""
-
-
-class UserListResponse(BaseResponse[list]):
- """Response model for user list operations."""
-
-
-@app.post("/configure", summary="Configure MemOS", response_model=ConfigResponse)
-async def set_config(config: MOSConfig):
- """Set MemOS configuration."""
- global MOS_INSTANCE
-
- # Create a temporary user manager to check/create default user
- temp_user_manager = UserManager()
-
- # Create default user if it doesn't exist
- if not temp_user_manager.validate_user(config.user_id):
- temp_user_manager.create_user(
- user_name=config.user_id, role=UserRole.USER, user_id=config.user_id
- )
- logger.info(f"Created default user: {config.user_id}")
-
- # Now create the MOS instance
- MOS_INSTANCE = MOS(config=config)
- return ConfigResponse(message="Configuration set successfully")
-
-
-@app.post("/users", summary="Create a new user", response_model=UserResponse)
-async def create_user(user_create: UserCreate):
- """Create a new user."""
- mos_instance = get_mos_instance()
- role = UserRole(user_create.role)
- user_id = mos_instance.create_user(
- user_id=user_create.user_id, role=role, user_name=user_create.user_name
- )
- return UserResponse(message="User created successfully", data={"user_id": user_id})
-
-
-@app.get("/users", summary="List all users", response_model=UserListResponse)
-async def list_users():
- """List all active users."""
- mos_instance = get_mos_instance()
- users = mos_instance.list_users()
- return UserListResponse(message="Users retrieved successfully", data=users)
-
-
-@app.get("/users/me", summary="Get current user info", response_model=UserResponse)
-async def get_user_info():
- """Get current user information including accessible cubes."""
- mos_instance = get_mos_instance()
- user_info = mos_instance.get_user_info()
- return UserResponse(message="User info retrieved successfully", data=user_info)
-
-
-@app.post("/mem_cubes", summary="Register a MemCube", response_model=SimpleResponse)
-async def register_mem_cube(mem_cube: MemCubeRegister):
- """Register a new MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.register_mem_cube(
- mem_cube_name_or_path=mem_cube.mem_cube_name_or_path,
- mem_cube_id=mem_cube.mem_cube_id,
- user_id=mem_cube.user_id,
- )
- return SimpleResponse(message="MemCube registered successfully")
-
-
-@app.delete(
- "/mem_cubes/{mem_cube_id}", summary="Unregister a MemCube", response_model=SimpleResponse
-)
-async def unregister_mem_cube(mem_cube_id: str, user_id: str | None = None):
- """Unregister a MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.unregister_mem_cube(mem_cube_id=mem_cube_id, user_id=user_id)
- return SimpleResponse(message="MemCube unregistered successfully")
-
-
-@app.post(
- "/mem_cubes/{cube_id}/share",
- summary="Share a cube with another user",
- response_model=SimpleResponse,
-)
-async def share_cube(cube_id: str, share_request: CubeShare):
- """Share a cube with another user."""
- mos_instance = get_mos_instance()
- success = mos_instance.share_cube_with_user(cube_id, share_request.target_user_id)
- if success:
- return SimpleResponse(message="Cube shared successfully")
- else:
- raise ValueError("Failed to share cube")
-
-
-@app.post("/memories", summary="Create memories", response_model=SimpleResponse)
-async def add_memory(memory_create: MemoryCreate):
- """Store new memories in a MemCube."""
- if not any([memory_create.messages, memory_create.memory_content, memory_create.doc_path]):
- raise ValueError("Either messages, memory_content, or doc_path must be provided")
- mos_instance = get_mos_instance()
- if memory_create.messages:
- messages = [m.model_dump() for m in memory_create.messages]
- mos_instance.add(
- messages=messages,
- mem_cube_id=memory_create.mem_cube_id,
- user_id=memory_create.user_id,
- )
- elif memory_create.memory_content:
- mos_instance.add(
- memory_content=memory_create.memory_content,
- mem_cube_id=memory_create.mem_cube_id,
- user_id=memory_create.user_id,
- )
- elif memory_create.doc_path:
- mos_instance.add(
- doc_path=memory_create.doc_path,
- mem_cube_id=memory_create.mem_cube_id,
- user_id=memory_create.user_id,
- )
- return SimpleResponse(message="Memories added successfully")
-
-
-@app.get("/memories", summary="Get all memories", response_model=MemoryResponse)
-async def get_all_memories(
- mem_cube_id: str | None = None,
- user_id: str | None = None,
-):
- """Retrieve all memories from a MemCube."""
- mos_instance = get_mos_instance()
- result = mos_instance.get_all(mem_cube_id=mem_cube_id, user_id=user_id)
- return MemoryResponse(message="Memories retrieved successfully", data=result)
-
-
-@app.get(
- "/memories/{mem_cube_id}/{memory_id}", summary="Get a memory", response_model=MemoryResponse
-)
-async def get_memory(mem_cube_id: str, memory_id: str, user_id: str | None = None):
- """Retrieve a specific memory by ID from a MemCube."""
- mos_instance = get_mos_instance()
- result = mos_instance.get(mem_cube_id=mem_cube_id, memory_id=memory_id, user_id=user_id)
- return MemoryResponse(message="Memory retrieved successfully", data=result)
-
-
-@app.post("/search", summary="Search memories", response_model=SearchResponse)
-async def search_memories(search_req: SearchRequest):
- """Search for memories across MemCubes."""
- mos_instance = get_mos_instance()
- result = mos_instance.search(
- query=search_req.query,
- user_id=search_req.user_id,
- install_cube_ids=search_req.install_cube_ids,
- )
- return SearchResponse(message="Search completed successfully", data=result)
-
-
-@app.put(
- "/memories/{mem_cube_id}/{memory_id}", summary="Update a memory", response_model=SimpleResponse
-)
-async def update_memory(
- mem_cube_id: str, memory_id: str, updated_memory: dict[str, Any], user_id: str | None = None
-):
- """Update an existing memory in a MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.update(
- mem_cube_id=mem_cube_id,
- memory_id=memory_id,
- text_memory_item=updated_memory,
- user_id=user_id,
- )
- return SimpleResponse(message="Memory updated successfully")
-
-
-@app.delete(
- "/memories/{mem_cube_id}/{memory_id}", summary="Delete a memory", response_model=SimpleResponse
-)
-async def delete_memory(mem_cube_id: str, memory_id: str, user_id: str | None = None):
- """Delete a specific memory from a MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.delete(mem_cube_id=mem_cube_id, memory_id=memory_id, user_id=user_id)
- return SimpleResponse(message="Memory deleted successfully")
-
-
-@app.delete("/memories/{mem_cube_id}", summary="Delete all memories", response_model=SimpleResponse)
-async def delete_all_memories(mem_cube_id: str, user_id: str | None = None):
- """Delete all memories from a MemCube."""
- mos_instance = get_mos_instance()
- mos_instance.delete_all(mem_cube_id=mem_cube_id, user_id=user_id)
- return SimpleResponse(message="All memories deleted successfully")
-
-
-@app.post("/chat", summary="Chat with MemOS", response_model=ChatResponse)
-async def chat(chat_req: ChatRequest):
- """Chat with the MemOS system."""
- mos_instance = get_mos_instance()
- response = mos_instance.chat(query=chat_req.query, user_id=chat_req.user_id)
- if response is None:
- raise ValueError("No response generated")
- return ChatResponse(message="Chat response generated", data=response)
-
-
-@app.get("/", summary="Redirect to the OpenAPI documentation", include_in_schema=False)
-async def home():
- """Redirect to the OpenAPI documentation."""
- return RedirectResponse(url="/docs", status_code=307)
-
-
-@app.exception_handler(ValueError)
-async def value_error_handler(request: Request, exc: ValueError):
- """Handle ValueError exceptions globally."""
- return JSONResponse(
- status_code=400,
- content={"code": 400, "message": str(exc), "data": None},
- )
-
-
-@app.exception_handler(Exception)
-async def global_exception_handler(request: Request, exc: Exception):
- """Handle all unhandled exceptions globally."""
- logger.exception("Unhandled error:")
- return JSONResponse(
- status_code=500,
- content={"code": 500, "message": str(exc), "data": None},
- )
-
-
-if __name__ == "__main__":
- import argparse
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
- parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
- parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")
- args = parser.parse_args()
diff --git a/src/memos/cli.py b/src/memos/cli.py
index 092f2d276..2ead5ab29 100644
--- a/src/memos/cli.py
+++ b/src/memos/cli.py
@@ -11,9 +11,16 @@
from io import BytesIO
+def get_openapi_app():
+ """Return the FastAPI app used for OpenAPI export."""
+ from memos.api.server_api import app
+
+ return app
+
+
def export_openapi(output: str) -> bool:
"""Export OpenAPI schema to JSON file."""
- from memos.api.server_api import app
+ app = get_openapi_app()
# Create directory if it doesn't exist
if os.path.dirname(output):
diff --git a/src/memos/configs/graph_db.py b/src/memos/configs/graph_db.py
index 5900d2357..98de09812 100644
--- a/src/memos/configs/graph_db.py
+++ b/src/memos/configs/graph_db.py
@@ -103,57 +103,6 @@ def validate_community(self):
return self
-class NebulaGraphDBConfig(BaseGraphDBConfig):
- """
- NebulaGraph-specific configuration.
-
- Key concepts:
- - `space`: Equivalent to a database or namespace. All tag/edge/schema live within a space.
- - `user_name`: Used for logical tenant isolation if needed.
- - `auto_create`: Whether to automatically create the target space if it does not exist.
-
- Example:
- ---
- hosts = ["127.0.0.1:9669"]
- user = "root"
- password = "nebula"
- space = "shared_graph"
- user_name = "alice"
- """
-
- space: str = Field(
- ..., description="The name of the target NebulaGraph space (like a database)"
- )
- user_name: str | None = Field(
- default=None,
- description="Logical user or tenant ID for data isolation (optional, used in metadata tagging)",
- )
- auto_create: bool = Field(
- default=False,
- description="Whether to auto-create the space if it does not exist",
- )
- use_multi_db: bool = Field(
- default=True,
- description=(
- "If True: use Neo4j's multi-database feature for physical isolation; "
- "each user typically gets a separate database. "
- "If False: use a single shared database with logical isolation by user_name."
- ),
- )
- max_client: int = Field(
- default=1000,
- description=("max_client"),
- )
- embedding_dimension: int = Field(default=3072, description="Dimension of vector embedding")
-
- @model_validator(mode="after")
- def validate_config(self):
- """Validate config."""
- if not self.space:
- raise ValueError("`space` must be provided")
- return self
-
-
class PolarDBGraphDBConfig(BaseConfig):
"""
PolarDB-specific configuration.
@@ -299,7 +248,6 @@ class GraphDBConfigFactory(BaseModel):
backend_to_class: ClassVar[dict[str, Any]] = {
"neo4j": Neo4jGraphDBConfig,
"neo4j-community": Neo4jCommunityGraphDBConfig,
- "nebular": NebulaGraphDBConfig,
"polardb": PolarDBGraphDBConfig,
"postgres": PostgresGraphDBConfig,
}
diff --git a/src/memos/context/context.py b/src/memos/context/context.py
index 5c8401732..5347de880 100644
--- a/src/memos/context/context.py
+++ b/src/memos/context/context.py
@@ -155,7 +155,7 @@ def get_current_user_name() -> str | None:
def get_current_source() -> str | None:
"""
- Get the current request's source (e.g., 'product_api' or 'server_api').
+ Get the current request's source (for example, 'server_api').
"""
context = _request_context.get()
if context:
diff --git a/src/memos/graph_dbs/factory.py b/src/memos/graph_dbs/factory.py
index c207e3190..93b5971ec 100644
--- a/src/memos/graph_dbs/factory.py
+++ b/src/memos/graph_dbs/factory.py
@@ -2,7 +2,6 @@
from memos.configs.graph_db import GraphDBConfigFactory
from memos.graph_dbs.base import BaseGraphDB
-from memos.graph_dbs.nebular import NebulaGraphDB
from memos.graph_dbs.neo4j import Neo4jGraphDB
from memos.graph_dbs.neo4j_community import Neo4jCommunityGraphDB
from memos.graph_dbs.polardb import PolarDBGraphDB
@@ -15,7 +14,6 @@ class GraphStoreFactory(BaseGraphDB):
backend_to_class: ClassVar[dict[str, Any]] = {
"neo4j": Neo4jGraphDB,
"neo4j-community": Neo4jCommunityGraphDB,
- "nebular": NebulaGraphDB,
"polardb": PolarDBGraphDB,
"postgres": PostgresGraphDB,
}
diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py
deleted file mode 100644
index 428d6d09e..000000000
--- a/src/memos/graph_dbs/nebular.py
+++ /dev/null
@@ -1,1794 +0,0 @@
-import json
-import traceback
-
-from contextlib import suppress
-from datetime import datetime
-from threading import Lock
-from typing import TYPE_CHECKING, Any, ClassVar, Literal
-
-import numpy as np
-
-from memos.configs.graph_db import NebulaGraphDBConfig
-from memos.dependency import require_python_package
-from memos.graph_dbs.base import BaseGraphDB
-from memos.log import get_logger
-from memos.utils import timed
-
-
-if TYPE_CHECKING:
- from nebulagraph_python import (
- NebulaClient,
- )
-
-
-logger = get_logger(__name__)
-
-
-_TRANSIENT_ERR_KEYS = (
- "Session not found",
- "Connection not established",
- "timeout",
- "deadline exceeded",
- "Broken pipe",
- "EOFError",
- "socket closed",
- "connection reset",
- "connection refused",
-)
-
-
-@timed
-def _normalize(vec: list[float]) -> list[float]:
- v = np.asarray(vec, dtype=np.float32)
- norm = np.linalg.norm(v)
- return (v / (norm if norm else 1.0)).tolist()
-
-
-@timed
-def _compose_node(item: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
- node_id = item["id"]
- memory = item["memory"]
- metadata = item.get("metadata", {})
- return node_id, memory, metadata
-
-
-@timed
-def _escape_str(value: str) -> str:
- out = []
- for ch in value:
- code = ord(ch)
- if ch == "\\":
- out.append("\\\\")
- elif ch == '"':
- out.append('\\"')
- elif ch == "\n":
- out.append("\\n")
- elif ch == "\r":
- out.append("\\r")
- elif ch == "\t":
- out.append("\\t")
- elif ch == "\b":
- out.append("\\b")
- elif ch == "\f":
- out.append("\\f")
- elif code < 0x20 or code in (0x2028, 0x2029):
- out.append(f"\\u{code:04x}")
- else:
- out.append(ch)
- return "".join(out)
-
-
-@timed
-def _format_datetime(value: str | datetime) -> str:
- """Ensure datetime is in ISO 8601 format string."""
- if isinstance(value, datetime):
- return value.isoformat()
- return str(value)
-
-
-@timed
-def _normalize_datetime(val):
- """
- Normalize datetime to ISO 8601 UTC string with +00:00.
- - If val is datetime object -> keep isoformat() (Neo4j)
- - If val is string without timezone -> append +00:00 (Nebula)
- - Otherwise just str()
- """
- if hasattr(val, "isoformat"):
- return val.isoformat()
- if isinstance(val, str) and not val.endswith(("+00:00", "Z", "+08:00")):
- return val + "+08:00"
- return str(val)
-
-
-class NebulaGraphDB(BaseGraphDB):
- """
- NebulaGraph-based implementation of a graph memory store.
- """
-
- # ====== shared pool cache & refcount ======
- # These are process-local; in a multi-process model each process will
- # have its own cache.
- _CLIENT_CACHE: ClassVar[dict[str, "NebulaClient"]] = {}
- _CLIENT_REFCOUNT: ClassVar[dict[str, int]] = {}
- _CLIENT_LOCK: ClassVar[Lock] = Lock()
- _CLIENT_INIT_DONE: ClassVar[set[str]] = set()
-
- @staticmethod
- def _get_hosts_from_cfg(cfg: NebulaGraphDBConfig) -> list[str]:
- hosts = getattr(cfg, "uri", None) or getattr(cfg, "hosts", None)
- if isinstance(hosts, str):
- return [hosts]
- return list(hosts or [])
-
- @staticmethod
- def _make_client_key(cfg: NebulaGraphDBConfig) -> str:
- hosts = NebulaGraphDB._get_hosts_from_cfg(cfg)
- return "|".join(
- [
- "nebula-sync",
- ",".join(hosts),
- str(getattr(cfg, "user", "")),
- str(getattr(cfg, "space", "")),
- ]
- )
-
- @classmethod
- def _bootstrap_admin(cls, cfg: NebulaGraphDBConfig, client: "NebulaClient") -> "NebulaGraphDB":
- tmp = object.__new__(NebulaGraphDB)
- tmp.config = cfg
- tmp.db_name = cfg.space
- tmp.user_name = None
- tmp.embedding_dimension = getattr(cfg, "embedding_dimension", 3072)
- tmp.default_memory_dimension = 3072
- tmp.common_fields = {
- "id",
- "memory",
- "user_name",
- "user_id",
- "session_id",
- "status",
- "key",
- "confidence",
- "tags",
- "created_at",
- "updated_at",
- "memory_type",
- "sources",
- "source",
- "node_type",
- "visibility",
- "usage",
- "background",
- }
- tmp.base_fields = set(tmp.common_fields) - {"usage"}
- tmp.heavy_fields = {"usage"}
- tmp.dim_field = (
- f"embedding_{tmp.embedding_dimension}"
- if str(tmp.embedding_dimension) != str(tmp.default_memory_dimension)
- else "embedding"
- )
- tmp.system_db_name = cfg.space
- tmp._client = client
- tmp._owns_client = False
- return tmp
-
- @classmethod
- def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "NebulaClient"]:
- from nebulagraph_python import (
- ConnectionConfig,
- NebulaClient,
- SessionConfig,
- SessionPoolConfig,
- )
-
- key = cls._make_client_key(cfg)
- with cls._CLIENT_LOCK:
- client = cls._CLIENT_CACHE.get(key)
- if client is None:
- # Connection setting
-
- tmp_client = NebulaClient(
- hosts=cfg.uri,
- username=cfg.user,
- password=cfg.password,
- session_config=SessionConfig(graph=None),
- session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000),
- )
- try:
- cls._ensure_space_exists(tmp_client, cfg)
- finally:
- tmp_client.close()
-
- conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
- if conn_conf is None:
- conn_conf = ConnectionConfig.from_defults(
- cls._get_hosts_from_cfg(cfg),
- getattr(cfg, "ssl_param", None),
- )
-
- sess_conf = SessionConfig(graph=getattr(cfg, "space", None))
- pool_conf = SessionPoolConfig(
- size=int(getattr(cfg, "max_client", 1000)), wait_timeout=5000
- )
-
- client = NebulaClient(
- hosts=conn_conf.hosts,
- username=cfg.user,
- password=cfg.password,
- conn_config=conn_conf,
- session_config=sess_conf,
- session_pool_config=pool_conf,
- )
- cls._CLIENT_CACHE[key] = client
- cls._CLIENT_REFCOUNT[key] = 0
- logger.info(f"[NebulaGraphDBSync] Created shared NebulaClient key={key}")
-
- cls._CLIENT_REFCOUNT[key] = cls._CLIENT_REFCOUNT.get(key, 0) + 1
-
- if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
- try:
- pass
- finally:
- pass
-
- if getattr(cfg, "auto_create", False) and key not in cls._CLIENT_INIT_DONE:
- with cls._CLIENT_LOCK:
- if key not in cls._CLIENT_INIT_DONE:
- admin = cls._bootstrap_admin(cfg, client)
- try:
- admin._ensure_database_exists()
- admin._create_basic_property_indexes()
- admin._create_vector_index(
- dimensions=int(
- admin.embedding_dimension or admin.default_memory_dimension
- ),
- )
- cls._CLIENT_INIT_DONE.add(key)
- logger.info("[NebulaGraphDBSync] One-time init done")
- except Exception:
- logger.exception("[NebulaGraphDBSync] One-time init failed")
-
- return key, client
-
- def _refresh_client(self):
- """
- refresh NebulaClient:
- """
- old_key = getattr(self, "_client_key", None)
- if not old_key:
- return
-
- cls = self.__class__
- with cls._CLIENT_LOCK:
- try:
- if old_key in cls._CLIENT_CACHE:
- try:
- cls._CLIENT_CACHE[old_key].close()
- except Exception as e:
- logger.warning(f"[refresh_client] close old client error: {e}")
- finally:
- cls._CLIENT_CACHE.pop(old_key, None)
- finally:
- cls._CLIENT_REFCOUNT[old_key] = 0
-
- new_key, new_client = cls._get_or_create_shared_client(self.config)
- self._client_key = new_key
- self._client = new_client
- logger.info(f"[NebulaGraphDBSync] client refreshed: {old_key} -> {new_key}")
-
- @classmethod
- def _release_shared_client(cls, key: str):
- with cls._CLIENT_LOCK:
- if key not in cls._CLIENT_CACHE:
- return
- cls._CLIENT_REFCOUNT[key] = max(0, cls._CLIENT_REFCOUNT.get(key, 0) - 1)
- if cls._CLIENT_REFCOUNT[key] == 0:
- try:
- cls._CLIENT_CACHE[key].close()
- except Exception as e:
- logger.warning(f"[NebulaGraphDBSync] Error closing client: {e}")
- finally:
- cls._CLIENT_CACHE.pop(key, None)
- cls._CLIENT_REFCOUNT.pop(key, None)
- logger.info(f"[NebulaGraphDBSync] Closed & removed client key={key}")
-
- @classmethod
- def close_all_shared_clients(cls):
- with cls._CLIENT_LOCK:
- for key, client in list(cls._CLIENT_CACHE.items()):
- try:
- client.close()
- except Exception as e:
- logger.warning(f"[NebulaGraphDBSync] Error closing client {key}: {e}")
- finally:
- logger.info(f"[NebulaGraphDBSync] Closed client key={key}")
- cls._CLIENT_CACHE.clear()
- cls._CLIENT_REFCOUNT.clear()
-
- @require_python_package(
- import_name="nebulagraph_python",
- install_command="pip install nebulagraph-python>=5.1.1",
- install_link=".....",
- )
- def __init__(self, config: NebulaGraphDBConfig):
- """
- NebulaGraph DB client initialization.
-
- Required config attributes:
- - hosts: list[str] like ["host1:port", "host2:port"]
- - user: str
- - password: str
- - db_name: str (optional for basic commands)
-
- Example config:
- {
- "hosts": ["xxx.xx.xx.xxx:xxxx"],
- "user": "root",
- "password": "nebula",
- "space": "test"
- }
- """
-
- assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED"
- self.config = config
- self.db_name = config.space
- self.user_name = config.user_name
- self.embedding_dimension = config.embedding_dimension
- self.default_memory_dimension = 3072
- self.common_fields = {
- "id",
- "memory",
- "user_name",
- "user_id",
- "session_id",
- "status",
- "key",
- "confidence",
- "tags",
- "created_at",
- "updated_at",
- "memory_type",
- "sources",
- "source",
- "node_type",
- "visibility",
- "usage",
- "background",
- }
- self.base_fields = set(self.common_fields) - {"usage"}
- self.heavy_fields = {"usage"}
- self.dim_field = (
- f"embedding_{self.embedding_dimension}"
- if (str(self.embedding_dimension) != str(self.default_memory_dimension))
- else "embedding"
- )
- self.system_db_name = config.space
-
- # ---- NEW: pool acquisition strategy
- # Get or create a shared pool from the class-level cache
- self._client_key, self._client = self._get_or_create_shared_client(config)
- self._owns_client = True
-
- logger.info("Connected to NebulaGraph successfully.")
-
- @timed
- def execute_query(self, gql: str, timeout: float = 60.0, auto_set_db: bool = True):
- def _wrap_use_db(q: str) -> str:
- if auto_set_db and self.db_name:
- return f"USE `{self.db_name}`\n{q}"
- return q
-
- try:
- return self._client.execute(_wrap_use_db(gql), timeout=timeout)
-
- except Exception as e:
- emsg = str(e)
- if any(k.lower() in emsg.lower() for k in _TRANSIENT_ERR_KEYS):
- logger.warning(f"[execute_query] {e!s} โ refreshing session pool and retry once...")
- try:
- self._refresh_client()
- return self._client.execute(_wrap_use_db(gql), timeout=timeout)
- except Exception:
- logger.exception("[execute_query] retry after refresh failed")
- raise
- raise
-
- @timed
- def close(self):
- """
- Close the connection resource if this instance owns it.
-
- - If pool was injected (`shared_pool`), do nothing.
- - If pool was acquired via shared cache, decrement refcount and close
- when the last owner releases it.
- """
- if not self._owns_client:
- logger.debug("[NebulaGraphDBSync] close() skipped (injected client).")
- return
- if self._client_key:
- self._release_shared_client(self._client_key)
- self._client_key = None
- self._client = None
-
- # NOTE: __del__ is best-effort; do not rely on GC order.
- def __del__(self):
- with suppress(Exception):
- self.close()
-
- @timed
- def create_index(
- self,
- label: str = "Memory",
- vector_property: str = "embedding",
- dimensions: int = 3072,
- index_name: str = "memory_vector_index",
- ) -> None:
- # Create vector index
- self._create_vector_index(label, vector_property, dimensions, index_name)
- # Create indexes
- self._create_basic_property_indexes()
-
- @timed
- def remove_oldest_memory(
- self, memory_type: str, keep_latest: int, user_name: str | None = None
- ) -> None:
- """
- Remove all WorkingMemory nodes except the latest `keep_latest` entries.
-
- Args:
- memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
- keep_latest (int): Number of latest WorkingMemory entries to keep.
- user_name(str): optional user_name.
- """
- try:
- user_name = user_name if user_name else self.config.user_name
- optional_condition = f"AND n.user_name = '{user_name}'"
- count = self.count_nodes(memory_type, user_name)
- if count > keep_latest:
- delete_query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE n.memory_type = '{memory_type}'
- {optional_condition}
- ORDER BY n.updated_at DESC
- OFFSET {int(keep_latest)}
- DETACH DELETE n
- """
- self.execute_query(delete_query)
- except Exception as e:
- logger.warning(f"Delete old mem error: {e}")
-
- @timed
- def add_node(
- self, id: str, memory: str, metadata: dict[str, Any], user_name: str | None = None
- ) -> None:
- """
- Insert or update a Memory node in NebulaGraph.
- """
- metadata["user_name"] = user_name if user_name else self.config.user_name
- now = datetime.utcnow()
- metadata = metadata.copy()
- metadata.setdefault("created_at", now)
- metadata.setdefault("updated_at", now)
- metadata["node_type"] = metadata.pop("type")
- metadata["id"] = id
- metadata["memory"] = memory
-
- if "embedding" in metadata and isinstance(metadata["embedding"], list):
- assert len(metadata["embedding"]) == self.embedding_dimension, (
- f"input embedding dimension must equal to {self.embedding_dimension}"
- )
- embedding = metadata.pop("embedding")
- metadata[self.dim_field] = _normalize(embedding)
-
- metadata = self._metadata_filter(metadata)
- properties = ", ".join(f"{k}: {self._format_value(v, k)}" for k, v in metadata.items())
- gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
-
- try:
- self.execute_query(gql)
- logger.info("insert success")
- except Exception as e:
- logger.error(
- f"Failed to insert vertex {id}: gql: {gql}, {e}\ntrace: {traceback.format_exc()}"
- )
-
- @timed
- def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- filter_clause = f'n.memory_type = "{scope}" AND n.user_name = "{user_name}"'
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {filter_clause}
- RETURN n.id AS id
- LIMIT 1
- """
-
- try:
- result = self.execute_query(query)
- return result.size == 0
- except Exception as e:
- logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
- raise
-
- @timed
- def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = None) -> None:
- """
- Update node fields in Nebular, auto-converting `created_at` and `updated_at` to datetime type if present.
- """
- user_name = user_name if user_name else self.config.user_name
- fields = fields.copy()
- set_clauses = []
- for k, v in fields.items():
- set_clauses.append(f"n.{k} = {self._format_value(v, k)}")
-
- set_clause_str = ",\n ".join(set_clauses)
-
- query = f"""
- MATCH (n@Memory {{id: "{id}"}})
- """
- query += f'WHERE n.user_name = "{user_name}"'
-
- query += f"\nSET {set_clause_str}"
- self.execute_query(query)
-
- @timed
- def delete_node(self, id: str, user_name: str | None = None) -> None:
- """
- Delete a node from the graph.
- Args:
- id: Node identifier to delete.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory {{id: "{id}"}}) WHERE n.user_name = {self._format_value(user_name)}
- DETACH DELETE n
- """
- self.execute_query(query)
-
- @timed
- def add_edge(self, source_id: str, target_id: str, type: str, user_name: str | None = None):
- """
- Create an edge from source node to target node.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type (e.g., 'RELATE_TO', 'PARENT').
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- if not source_id or not target_id:
- raise ValueError("[add_edge] source_id and target_id must be provided")
- user_name = user_name if user_name else self.config.user_name
- props = ""
- props = f'{{user_name: "{user_name}"}}'
- insert_stmt = f'''
- MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
- INSERT (a) -[e@{type} {props}]-> (b)
- '''
- try:
- self.execute_query(insert_stmt)
- except Exception as e:
- logger.error(f"Failed to insert edge: {e}", exc_info=True)
-
- @timed
- def delete_edge(
- self, source_id: str, target_id: str, type: str, user_name: str | None = None
- ) -> None:
- """
- Delete a specific edge between two nodes.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type to remove.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (a@Memory) -[r@{type}]-> (b@Memory)
- WHERE a.id = {self._format_value(source_id)} AND b.id = {self._format_value(target_id)}
- """
-
- query += f" AND a.user_name = {self._format_value(user_name)} AND b.user_name = {self._format_value(user_name)}"
- query += "\nDELETE r"
- self.execute_query(query)
-
- @timed
- def get_memory_count(self, memory_type: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory)
- WHERE n.memory_type = "{memory_type}"
- """
- query += f"\nAND n.user_name = '{user_name}'"
- query += "\nRETURN COUNT(n) AS count"
-
- try:
- result = self.execute_query(query)
- return result.one_or_none()["count"].value
- except Exception as e:
- logger.error(f"[get_memory_count] Failed: {e}")
- return -1
-
- @timed
- def count_nodes(self, scope: str, user_name: str | None = None) -> int:
- user_name = user_name if user_name else self.config.user_name
- query = f"""
- MATCH (n@Memory)
- WHERE n.memory_type = "{scope}"
- """
- query += f"\nAND n.user_name = '{user_name}'"
- query += "\nRETURN count(n) AS count"
-
- result = self.execute_query(query)
- return result.one_or_none()["count"].value
-
- @timed
- def edge_exists(
- self,
- source_id: str,
- target_id: str,
- type: str = "ANY",
- direction: str = "OUTGOING",
- user_name: str | None = None,
- ) -> bool:
- """
- Check if an edge exists between two nodes.
- Args:
- source_id: ID of the source node.
- target_id: ID of the target node.
- type: Relationship type. Use "ANY" to match any relationship type.
- direction: Direction of the edge.
- Use "OUTGOING" (default), "INCOMING", or "ANY".
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- True if the edge exists, otherwise False.
- """
- # Prepare the relationship pattern
- user_name = user_name if user_name else self.config.user_name
- rel = "r" if type == "ANY" else f"r@{type}"
-
- # Prepare the match pattern with direction
- if direction == "OUTGOING":
- pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]->(b@Memory {{id: '{target_id}'}})"
- elif direction == "INCOMING":
- pattern = f"(a@Memory {{id: '{source_id}'}})<-[{rel}]-(b@Memory {{id: '{target_id}'}})"
- elif direction == "ANY":
- pattern = f"(a@Memory {{id: '{source_id}'}})-[{rel}]-(b@Memory {{id: '{target_id}'}})"
- else:
- raise ValueError(
- f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'."
- )
- query = f"MATCH {pattern}"
- query += f"\nWHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
- query += "\nRETURN r"
-
- # Run the Cypher query
- result = self.execute_query(query)
- record = result.one_or_none()
- if record is None:
- return False
- return record.values() is not None
-
- @timed
- # Graph Query & Reasoning
- def get_node(
- self, id: str, include_embedding: bool = False, user_name: str | None = None
- ) -> dict[str, Any] | None:
- """
- Retrieve a Memory node by its unique ID.
-
- Args:
- id (str): Node ID (Memory.id)
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- dict: Node properties as key-value pairs, or None if not found.
- """
- filter_clause = f'n.id = "{id}"'
- return_fields = self._build_return_fields(include_embedding)
- gql = f"""
- MATCH (n@Memory)
- WHERE {filter_clause}
- RETURN {return_fields}
- """
-
- try:
- result = self.execute_query(gql)
- for row in result:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- return node
-
- except Exception as e:
- logger.error(
- f"[get_node] Failed to retrieve node '{id}': {e}, trace: {traceback.format_exc()}"
- )
- return None
-
- @timed
- def get_nodes(
- self,
- ids: list[str],
- include_embedding: bool = False,
- user_name: str | None = None,
- **kwargs,
- ) -> list[dict[str, Any]]:
- """
- Retrieve the metadata and memory of a list of nodes.
- Args:
- ids: List of Node identifier.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- list[dict]: Parsed node records containing 'id', 'memory', and 'metadata'.
-
- Notes:
- - Assumes all provided IDs are valid and exist.
- - Returns empty list if input is empty.
- """
- if not ids:
- return []
- # Safe formatting of the ID list
- id_list = ",".join(f'"{_id}"' for _id in ids)
-
- return_fields = self._build_return_fields(include_embedding)
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE n.id IN [{id_list}]
- RETURN {return_fields}
- """
- nodes = []
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- nodes.append(self._parse_node(props))
- except Exception as e:
- logger.error(
- f"[get_nodes] Failed to retrieve nodes {ids}: {e}, trace: {traceback.format_exc()}"
- )
- return nodes
-
- @timed
- def get_edges(
- self, id: str, type: str = "ANY", direction: str = "ANY", user_name: str | None = None
- ) -> list[dict[str, str]]:
- """
- Get edges connected to a node, with optional type and direction filter.
-
- Args:
- id: Node ID to retrieve edges for.
- type: Relationship type to match, or 'ANY' to match all.
- direction: 'OUTGOING', 'INCOMING', or 'ANY'.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- List of edges:
- [
- {"from": "source_id", "to": "target_id", "type": "RELATE"},
- ...
- ]
- """
- # Build relationship type filter
- rel_type = "" if type == "ANY" else f"@{type}"
- user_name = user_name if user_name else self.config.user_name
- # Build Cypher pattern based on direction
- if direction == "OUTGOING":
- pattern = f"(a@Memory)-[r{rel_type}]->(b@Memory)"
- where_clause = f"a.id = '{id}'"
- elif direction == "INCOMING":
- pattern = f"(a@Memory)<-[r{rel_type}]-(b@Memory)"
- where_clause = f"a.id = '{id}'"
- elif direction == "ANY":
- pattern = f"(a@Memory)-[r{rel_type}]-(b@Memory)"
- where_clause = f"a.id = '{id}' OR b.id = '{id}'"
- else:
- raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.")
-
- where_clause += f" AND a.user_name = '{user_name}' AND b.user_name = '{user_name}'"
-
- query = f"""
- MATCH {pattern}
- WHERE {where_clause}
- RETURN a.id AS from_id, b.id AS to_id, type(r) AS edge_type
- """
-
- result = self.execute_query(query)
- edges = []
- for record in result:
- edges.append(
- {
- "from": record["from_id"].value,
- "to": record["to_id"].value,
- "type": record["edge_type"].value,
- }
- )
- return edges
-
- @timed
- def get_neighbors_by_tag(
- self,
- tags: list[str],
- exclude_ids: list[str],
- top_k: int = 5,
- min_overlap: int = 1,
- include_embedding: bool = False,
- user_name: str | None = None,
- ) -> list[dict[str, Any]]:
- """
- Find top-K neighbor nodes with maximum tag overlap.
-
- Args:
- tags: The list of tags to match.
- exclude_ids: Node IDs to exclude (e.g., local cluster).
- top_k: Max number of neighbors to return.
- min_overlap: Minimum number of overlapping tags required.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- List of dicts with node details and overlap count.
- """
- if not tags:
- return []
- user_name = user_name if user_name else self.config.user_name
- where_clauses = [
- 'n.status = "activated"',
- 'NOT (n.node_type = "reasoning")',
- 'NOT (n.memory_type = "WorkingMemory")',
- ]
- if exclude_ids:
- where_clauses.append(f"NOT (n.id IN {exclude_ids})")
-
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- where_clause = " AND ".join(where_clauses)
- tag_list_literal = "[" + ", ".join(f'"{_escape_str(t)}"' for t in tags) + "]"
-
- return_fields = self._build_return_fields(include_embedding)
- query = f"""
- LET tag_list = {tag_list_literal}
-
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {where_clause}
- RETURN {return_fields},
- size( filter( n.tags, t -> t IN tag_list ) ) AS overlap_count
- ORDER BY overlap_count DESC
- LIMIT {top_k}
- """
-
- result = self.execute_query(query)
- neighbors: list[dict[str, Any]] = []
- for r in result:
- props = {k: v.value for k, v in r.items() if k != "overlap_count"}
- parsed = self._parse_node(props)
- parsed["overlap_count"] = r["overlap_count"].value
- neighbors.append(parsed)
-
- neighbors.sort(key=lambda x: x["overlap_count"], reverse=True)
- neighbors = neighbors[:top_k]
- result = []
- for neighbor in neighbors[:top_k]:
- neighbor.pop("overlap_count")
- result.append(neighbor)
- return result
-
- @timed
- def get_children_with_embeddings(
- self, id: str, user_name: str | None = None
- ) -> list[dict[str, Any]]:
- user_name = user_name if user_name else self.config.user_name
- where_user = f"AND p.user_name = '{user_name}' AND c.user_name = '{user_name}'"
-
- query = f"""
- MATCH (p@Memory)-[@PARENT]->(c@Memory)
- WHERE p.id = "{id}" {where_user}
- RETURN c.id AS id, c.{self.dim_field} AS {self.dim_field}, c.memory AS memory
- """
- result = self.execute_query(query)
- children = []
- for row in result:
- eid = row["id"].value # STRING
- emb_v = row[self.dim_field].value # NVector
- emb = list(emb_v.values) if emb_v else []
- mem = row["memory"].value # STRING
-
- children.append({"id": eid, "embedding": emb, "memory": mem})
- return children
-
- @timed
- def get_subgraph(
- self,
- center_id: str,
- depth: int = 2,
- center_status: str = "activated",
- user_name: str | None = None,
- ) -> dict[str, Any]:
- """
- Retrieve a local subgraph centered at a given node.
- Args:
- center_id: The ID of the center node.
- depth: The hop distance for neighbors.
- center_status: Required status for center node.
- user_name (str, optional): User name for filtering in non-multi-db mode
- Returns:
- {
- "core_node": {...},
- "neighbors": [...],
- "edges": [...]
- }
- """
- if not 1 <= depth <= 5:
- raise ValueError("depth must be 1-5")
-
- user_name = user_name if user_name else self.config.user_name
-
- gql = f"""
- MATCH (center@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE center.id = '{center_id}'
- AND center.status = '{center_status}'
- AND center.user_name = '{user_name}'
- OPTIONAL MATCH p = (center)-[e]->{{1,{depth}}}(neighbor@Memory)
- WHERE neighbor.user_name = '{user_name}'
- RETURN center,
- collect(DISTINCT neighbor) AS neighbors,
- collect(EDGES(p)) AS edge_chains
- """
-
- result = self.execute_query(gql).one_or_none()
- if not result or result.size == 0:
- return {"core_node": None, "neighbors": [], "edges": []}
-
- core_node_props = result["center"].as_node().get_properties()
- core_node = self._parse_node(core_node_props)
- neighbors = []
- vid_to_id_map = {result["center"].as_node().node_id: core_node["id"]}
- for n in result["neighbors"].value:
- n_node = n.as_node()
- n_props = n_node.get_properties()
- node_parsed = self._parse_node(n_props)
- neighbors.append(node_parsed)
- vid_to_id_map[n_node.node_id] = node_parsed["id"]
-
- edges = []
- for chain_group in result["edge_chains"].value:
- for edge_wr in chain_group.value:
- edge = edge_wr.value
- edges.append(
- {
- "type": edge.get_type(),
- "source": vid_to_id_map.get(edge.get_src_id()),
- "target": vid_to_id_map.get(edge.get_dst_id()),
- }
- )
-
- return {"core_node": core_node, "neighbors": neighbors, "edges": edges}
-
- @timed
- # Search / recall operations
- def search_by_embedding(
- self,
- vector: list[float],
- top_k: int = 5,
- scope: str | None = None,
- status: str | None = None,
- threshold: float | None = None,
- search_filter: dict | None = None,
- user_name: str | None = None,
- **kwargs,
- ) -> list[dict]:
- """
- Retrieve node IDs based on vector similarity.
-
- Args:
- vector (list[float]): The embedding vector representing query semantics.
- top_k (int): Number of top similar nodes to retrieve.
- scope (str, optional): Memory type filter (e.g., 'WorkingMemory', 'LongTermMemory').
- status (str, optional): Node status filter (e.g., 'active', 'archived').
- If provided, restricts results to nodes with matching status.
- threshold (float, optional): Minimum similarity score threshold (0 ~ 1).
- search_filter (dict, optional): Additional metadata filters for search results.
- Keys should match node properties, values are the expected values.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: A list of dicts with 'id' and 'score', ordered by similarity.
-
- Notes:
- - This method uses Neo4j native vector indexing to search for similar nodes.
- - If scope is provided, it restricts results to nodes with matching memory_type.
- - If 'status' is provided, only nodes with the matching status will be returned.
- - If threshold is provided, only results with score >= threshold will be returned.
- - If search_filter is provided, additional WHERE clauses will be added for metadata filtering.
- - Typical use case: restrict to 'status = activated' to avoid
- matching archived or merged nodes.
- """
- user_name = user_name if user_name else self.config.user_name
- vector = _normalize(vector)
- dim = len(vector)
- vector_str = ",".join(f"{float(x)}" for x in vector)
- gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"
- where_clauses = [f"n.{self.dim_field} IS NOT NULL"]
- if scope:
- where_clauses.append(f'n.memory_type = "{scope}"')
- if status:
- where_clauses.append(f'n.status = "{status}"')
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- # Add search_filter conditions
- if search_filter:
- for key, value in search_filter.items():
- if isinstance(value, str):
- where_clauses.append(f'n.{key} = "{value}"')
- else:
- where_clauses.append(f"n.{key} = {value}")
-
- where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
-
- gql = f"""
- let a = {gql_vector}
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- ORDER BY inner_product(n.{self.dim_field}, a) DESC
- LIMIT {top_k}
- RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
- try:
- result = self.execute_query(gql)
- except Exception as e:
- logger.error(f"[search_by_embedding] Query failed: {e}")
- return []
-
- try:
- output = []
- for row in result:
- values = row.values()
- id_val = values[0].as_string()
- score_val = values[1].as_double()
- score_val = (score_val + 1) / 2 # align to neo4j, Normalized Cosine Score
- if threshold is None or score_val >= threshold:
- output.append({"id": id_val, "score": score_val})
- return output
- except Exception as e:
- logger.error(f"[search_by_embedding] Result parse failed: {e}")
- return []
-
- @timed
- def get_by_metadata(
- self, filters: list[dict[str, Any]], user_name: str | None = None
- ) -> list[str]:
- """
- 1. ADD logic: "AND" vs "OR"(support logic combination);
- 2. Support nested conditional expressions;
-
- Retrieve node IDs that match given metadata filters.
- Supports exact match.
-
- Args:
- filters: List of filter dicts like:
- [
- {"field": "key", "op": "in", "value": ["A", "B"]},
- {"field": "confidence", "op": ">=", "value": 80},
- {"field": "tags", "op": "contains", "value": "AI"},
- ...
- ]
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[str]: Node IDs whose metadata match the filter conditions. (AND logic).
-
- Notes:
- - Supports structured querying such as tag/category/importance/time filtering.
- - Can be used for faceted recall or prefiltering before embedding rerank.
- """
- where_clauses = []
- user_name = user_name if user_name else self.config.user_name
- for _i, f in enumerate(filters):
- field = f["field"]
- op = f.get("op", "=")
- value = f["value"]
-
- escaped_value = self._format_value(value)
-
- # Build WHERE clause
- if op == "=":
- where_clauses.append(f"n.{field} = {escaped_value}")
- elif op == "in":
- where_clauses.append(f"n.{field} IN {escaped_value}")
- elif op == "contains":
- where_clauses.append(f"size(filter(n.{field}, t -> t IN {escaped_value})) > 0")
- elif op == "starts_with":
- where_clauses.append(f"n.{field} STARTS WITH {escaped_value}")
- elif op == "ends_with":
- where_clauses.append(f"n.{field} ENDS WITH {escaped_value}")
- elif op in [">", ">=", "<", "<="]:
- where_clauses.append(f"n.{field} {op} {escaped_value}")
- else:
- raise ValueError(f"Unsupported operator: {op}")
-
- where_clauses.append(f'n.user_name = "{user_name}"')
-
- where_str = " AND ".join(where_clauses)
- gql = f"MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) WHERE {where_str} RETURN n.id AS id"
- ids = []
- try:
- result = self.execute_query(gql)
- ids = [record["id"].value for record in result]
- except Exception as e:
- logger.error(f"Failed to get metadata: {e}, gql is {gql}")
- return ids
-
- @timed
- def get_grouped_counts(
- self,
- group_fields: list[str],
- where_clause: str = "",
- params: dict[str, Any] | None = None,
- user_name: str | None = None,
- ) -> list[dict[str, Any]]:
- """
- Count nodes grouped by any fields.
-
- Args:
- group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
- where_clause (str, optional): Extra WHERE condition. E.g.,
- "WHERE n.status = 'activated'"
- params (dict, optional): Parameters for WHERE clause.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...]
- """
- if not group_fields:
- raise ValueError("group_fields cannot be empty")
- user_name = user_name if user_name else self.config.user_name
- # GQL-specific modifications
- user_clause = f"n.user_name = '{user_name}'"
- if where_clause:
- where_clause = where_clause.strip()
- if where_clause.upper().startswith("WHERE"):
- where_clause += f" AND {user_clause}"
- else:
- where_clause = f"WHERE {where_clause} AND {user_clause}"
- else:
- where_clause = f"WHERE {user_clause}"
-
- # Inline parameters if provided
- if params:
- for key, value in params.items():
- # Handle different value types appropriately
- if isinstance(value, str):
- value = f"'{value}'"
- where_clause = where_clause.replace(f"${key}", str(value))
-
- return_fields = []
- group_by_fields = []
-
- for field in group_fields:
- alias = field.replace(".", "_")
- return_fields.append(f"n.{field} AS {alias}")
- group_by_fields.append(alias)
- # Full GQL query construction
- gql = f"""
- MATCH (n /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- RETURN {", ".join(return_fields)}, COUNT(n) AS count
- """
- result = self.execute_query(gql) # Pure GQL string execution
-
- output = []
- for record in result:
- group_values = {}
- for i, field in enumerate(group_fields):
- value = record.values()[i].as_string()
- group_values[field] = value
- count_value = record["count"].value
- output.append({**group_values, "count": count_value})
-
- return output
-
- @timed
- def clear(self, user_name: str | None = None) -> None:
- """
- Clear the entire graph if the target database exists.
-
- Args:
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- try:
- query = f"MATCH (n@Memory) WHERE n.user_name = '{user_name}' DETACH DELETE n"
- self.execute_query(query)
- logger.info("Cleared all nodes from database.")
-
- except Exception as e:
- logger.error(f"[ERROR] Failed to clear database: {e}")
-
- @timed
- def export_graph(
- self, include_embedding: bool = False, user_name: str | None = None, **kwargs
- ) -> dict[str, Any]:
- """
- Export all graph nodes and edges in a structured form.
- Args:
- include_embedding (bool): Whether to include the large embedding field.
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- {
- "nodes": [ { "id": ..., "memory": ..., "metadata": {...} }, ... ],
- "edges": [ { "source": ..., "target": ..., "type": ... }, ... ]
- }
- """
- user_name = user_name if user_name else self.config.user_name
- node_query = "MATCH (n@Memory)"
- edge_query = "MATCH (a@Memory)-[r]->(b@Memory)"
- node_query += f' WHERE n.user_name = "{user_name}"'
- edge_query += f' WHERE r.user_name = "{user_name}"'
-
- try:
- if include_embedding:
- return_fields = "n"
- else:
- return_fields = ",".join(
- [
- "n.id AS id",
- "n.memory AS memory",
- "n.user_name AS user_name",
- "n.user_id AS user_id",
- "n.session_id AS session_id",
- "n.status AS status",
- "n.key AS key",
- "n.confidence AS confidence",
- "n.tags AS tags",
- "n.created_at AS created_at",
- "n.updated_at AS updated_at",
- "n.memory_type AS memory_type",
- "n.sources AS sources",
- "n.source AS source",
- "n.node_type AS node_type",
- "n.visibility AS visibility",
- "n.usage AS usage",
- "n.background AS background",
- ]
- )
-
- full_node_query = f"{node_query} RETURN {return_fields}"
- node_result = self.execute_query(full_node_query, timeout=20)
- nodes = []
- logger.debug(f"Debugging: {node_result}")
- for row in node_result:
- if include_embedding:
- props = row.values()[0].as_node().get_properties()
- else:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- nodes.append(node)
- except Exception as e:
- raise RuntimeError(f"[EXPORT GRAPH - NODES] Exception: {e}") from e
-
- try:
- full_edge_query = f"{edge_query} RETURN a.id AS source, b.id AS target, type(r) as edge"
- edge_result = self.execute_query(full_edge_query, timeout=20)
- edges = [
- {
- "source": row.values()[0].value,
- "target": row.values()[1].value,
- "type": row.values()[2].value,
- }
- for row in edge_result
- ]
- except Exception as e:
- raise RuntimeError(f"[EXPORT GRAPH - EDGES] Exception: {e}") from e
-
- return {"nodes": nodes, "edges": edges}
-
- @timed
- def import_graph(self, data: dict[str, Any], user_name: str | None = None) -> None:
- """
- Import the entire graph from a serialized dictionary.
-
- Args:
- data: A dictionary containing all nodes and edges to be loaded.
- user_name (str, optional): User name for filtering in non-multi-db mode
- """
- user_name = user_name if user_name else self.config.user_name
- for node in data.get("nodes", []):
- try:
- id, memory, metadata = _compose_node(node)
- metadata["user_name"] = user_name
- metadata = self._prepare_node_metadata(metadata)
- metadata.update({"id": id, "memory": memory})
- properties = ", ".join(
- f"{k}: {self._format_value(v, k)}" for k, v in metadata.items()
- )
- node_gql = f"INSERT OR IGNORE (n@Memory {{{properties}}})"
- self.execute_query(node_gql)
- except Exception as e:
- logger.error(f"Fail to load node: {node}, error: {e}")
-
- for edge in data.get("edges", []):
- try:
- source_id, target_id = edge["source"], edge["target"]
- edge_type = edge["type"]
- props = f'{{user_name: "{user_name}"}}'
- edge_gql = f'''
- MATCH (a@Memory {{id: "{source_id}"}}), (b@Memory {{id: "{target_id}"}})
- INSERT OR IGNORE (a) -[e@{edge_type} {props}]-> (b)
- '''
- self.execute_query(edge_gql)
- except Exception as e:
- logger.error(f"Fail to load edge: {edge}, error: {e}")
-
- @timed
- def get_all_memory_items(
- self, scope: str, include_embedding: bool = False, user_name: str | None = None
- ) -> (list)[dict]:
- """
- Retrieve all memory items of a specific memory_type.
-
- Args:
- scope (str): Must be one of 'WorkingMemory', 'LongTermMemory', or 'UserMemory'.
- include_embedding: with/without embedding
- user_name (str, optional): User name for filtering in non-multi-db mode
-
- Returns:
- list[dict]: Full list of memory items under this scope.
- """
- user_name = user_name if user_name else self.config.user_name
- if scope not in {"WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"}:
- raise ValueError(f"Unsupported memory type scope: {scope}")
-
- where_clause = f"WHERE n.memory_type = '{scope}'"
- where_clause += f" AND n.user_name = '{user_name}'"
-
- return_fields = self._build_return_fields(include_embedding)
-
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- {where_clause}
- RETURN {return_fields}
- LIMIT 100
- """
- nodes = []
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- nodes.append(self._parse_node(props))
- except Exception as e:
- logger.error(f"Failed to get memories: {e}")
- return nodes
-
- @timed
- def get_structure_optimization_candidates(
- self, scope: str, include_embedding: bool = False, user_name: str | None = None
- ) -> list[dict]:
- """
- Find nodes that are likely candidates for structure optimization:
- - Isolated nodes, nodes with empty background, or nodes with exactly one child.
- - Plus: the child of any parent node that has exactly one child.
- """
- user_name = user_name if user_name else self.config.user_name
- where_clause = f'''
- n.memory_type = "{scope}"
- AND n.status = "activated"
- '''
- where_clause += f' AND n.user_name = "{user_name}"'
-
- return_fields = self._build_return_fields(include_embedding)
- return_fields += f", n.{self.dim_field} AS {self.dim_field}"
-
- query = f"""
- MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
- WHERE {where_clause}
- OPTIONAL MATCH (n)-[@PARENT]->(c@Memory)
- OPTIONAL MATCH (p@Memory)-[@PARENT]->(n)
- WHERE c IS NULL AND p IS NULL
- RETURN {return_fields}
- """
-
- candidates = []
- node_ids = set()
- try:
- results = self.execute_query(query)
- for row in results:
- props = {k: v.value for k, v in row.items()}
- node = self._parse_node(props)
- node_id = node["id"]
- if node_id not in node_ids:
- candidates.append(node)
- node_ids.add(node_id)
- except Exception as e:
- logger.error(f"Failed : {e}, traceback: {traceback.format_exc()}")
- return candidates
-
- @timed
- def drop_database(self) -> None:
- """
- Permanently delete the entire database this instance is using.
- WARNING: This operation is destructive and cannot be undone.
- """
- raise ValueError(
- f"Refusing to drop protected database: `{self.db_name}` in "
- f"Shared Database Multi-Tenant mode"
- )
-
- @timed
- def detect_conflicts(self) -> list[tuple[str, str]]:
- """
- Detect conflicting nodes based on logical or semantic inconsistency.
- Returns:
- A list of (node_id1, node_id2) tuples that conflict.
- """
- raise NotImplementedError
-
- @timed
- # Structure Maintenance
- def deduplicate_nodes(self) -> None:
- """
- Deduplicate redundant or semantically similar nodes.
- This typically involves identifying nodes with identical or near-identical memory.
- """
- raise NotImplementedError
-
- @timed
- def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]:
- """
- Get the ordered context chain starting from a node, following a relationship type.
- Args:
- id: Starting node ID.
- type: Relationship type to follow (e.g., 'FOLLOWS').
- Returns:
- List of ordered node IDs in the chain.
- """
- raise NotImplementedError
-
- @timed
- def get_neighbors(
- self, id: str, type: str, direction: Literal["in", "out", "both"] = "out"
- ) -> list[str]:
- """
- Get connected node IDs in a specific direction and relationship type.
- Args:
- id: Source node ID.
- type: Relationship type.
- direction: Edge direction to follow ('out', 'in', or 'both').
- Returns:
- List of neighboring node IDs.
- """
- raise NotImplementedError
-
- @timed
- def get_path(self, source_id: str, target_id: str, max_depth: int = 3) -> list[str]:
- """
- Get the path of nodes from source to target within a limited depth.
- Args:
- source_id: Starting node ID.
- target_id: Target node ID.
- max_depth: Maximum path length to traverse.
- Returns:
- Ordered list of node IDs along the path.
- """
- raise NotImplementedError
-
- @timed
- def merge_nodes(self, id1: str, id2: str) -> str:
- """
- Merge two similar or duplicate nodes into one.
- Args:
- id1: First node ID.
- id2: Second node ID.
- Returns:
- ID of the resulting merged node.
- """
- raise NotImplementedError
-
- @classmethod
- def _ensure_space_exists(cls, tmp_client, cfg):
- """Lightweight check to ensure target graph (space) exists."""
- db_name = getattr(cfg, "space", None)
- if not db_name:
- logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
- return
-
- try:
- res = tmp_client.execute("SHOW GRAPHS")
- existing = {row.values()[0].as_string() for row in res}
- if db_name not in existing:
- tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type")
- logger.info(f"โ
Graph `{db_name}` created before session binding.")
- else:
- logger.debug(f"Graph `{db_name}` already exists.")
- except Exception:
- logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")
-
- @timed
- def _ensure_database_exists(self):
- graph_type_name = "MemOSBgeM3Type"
-
- check_type_query = "SHOW GRAPH TYPES"
- result = self.execute_query(check_type_query, auto_set_db=False)
-
- type_exists = any(row["graph_type"].as_string() == graph_type_name for row in result)
-
- if not type_exists:
- create_tag = f"""
- CREATE GRAPH TYPE IF NOT EXISTS {graph_type_name} AS {{
- NODE Memory (:MemoryTag {{
- id STRING,
- memory STRING,
- user_name STRING,
- user_id STRING,
- session_id STRING,
- status STRING,
- key STRING,
- confidence FLOAT,
- tags LIST,
- created_at STRING,
- updated_at STRING,
- memory_type STRING,
- sources LIST,
- source STRING,
- node_type STRING,
- visibility STRING,
- usage LIST,
- background STRING,
- {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT>,
- PRIMARY KEY(id)
- }}),
- EDGE RELATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE PARENT (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE AGGREGATE_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE MERGED_TO (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE INFERS (Memory) -[{{user_name STRING}}]-> (Memory),
- EDGE FOLLOWS (Memory) -[{{user_name STRING}}]-> (Memory)
- }}
- """
- self.execute_query(create_tag, auto_set_db=False)
- else:
- describe_query = f"DESCRIBE NODE TYPE Memory OF {graph_type_name}"
- desc_result = self.execute_query(describe_query, auto_set_db=False)
-
- memory_fields = []
- for row in desc_result:
- field_name = row.values()[0].as_string()
- memory_fields.append(field_name)
-
- if self.dim_field not in memory_fields:
- alter_query = f"""
- ALTER GRAPH TYPE {graph_type_name} {{
- ALTER NODE TYPE Memory ADD PROPERTIES {{ {self.dim_field} VECTOR<{self.embedding_dimension}, FLOAT> }}
- }}
- """
- self.execute_query(alter_query, auto_set_db=False)
-                logger.info(f"✅ Add new vector search {self.dim_field} to {graph_type_name}")
- else:
-                logger.info(f"✅ Graph Type {graph_type_name} already include {self.dim_field}")
-
- create_graph = f"CREATE GRAPH IF NOT EXISTS `{self.db_name}` TYPED {graph_type_name}"
- try:
- self.execute_query(create_graph, auto_set_db=False)
-            logger.info(f"✅ Graph ``{self.db_name}`` is now the working graph.")
- except Exception as e:
-            logger.error(f"❌ Failed to create tag: {e} trace: {traceback.format_exc()}")
-
- @timed
- def _create_vector_index(
- self,
- label: str = "Memory",
- vector_property: str = "embedding",
- dimensions: int = 3072,
- index_name: str = "memory_vector_index",
- ) -> None:
- """
- Create a vector index for the specified property in the label.
- """
- if str(dimensions) == str(self.default_memory_dimension):
- index_name = f"idx_{vector_property}"
- vector_name = vector_property
- else:
- index_name = f"idx_{vector_property}_{dimensions}"
- vector_name = f"{vector_property}_{dimensions}"
-
- create_vector_index = f"""
- CREATE VECTOR INDEX IF NOT EXISTS {index_name}
- ON NODE {label}::{vector_name}
- OPTIONS {{
- DIM: {dimensions},
- METRIC: IP,
- TYPE: IVF,
- NLIST: 100,
- TRAINSIZE: 1000
- }}
- FOR `{self.db_name}`
- """
- self.execute_query(create_vector_index)
- logger.info(
-            f"✅ Ensure {label}::{vector_property} vector index {index_name} "
- f"exists (DIM={dimensions})"
- )
-
- @timed
- def _create_basic_property_indexes(self) -> None:
- """
- Create standard B-tree indexes on status, memory_type, created_at
- and updated_at fields.
- Create standard B-tree indexes on user_name when use Shared Database
- Multi-Tenant Mode.
- """
- fields = [
- "status",
- "memory_type",
- "created_at",
- "updated_at",
- "user_name",
- ]
-
- for field in fields:
- index_name = f"idx_memory_{field}"
- gql = f"""
- CREATE INDEX IF NOT EXISTS {index_name} ON NODE Memory({field})
- FOR `{self.db_name}`
- """
- try:
- self.execute_query(gql)
-                logger.info(f"✅ Created index: {index_name} on field {field}")
- except Exception as e:
- logger.error(
-                    f"❌ Failed to create index {index_name}: {e}, trace: {traceback.format_exc()}"
- )
-
- @timed
- def _index_exists(self, index_name: str) -> bool:
- """
- Check if an index with the given name exists.
- """
- """
- Check if a vector index with the given name exists in NebulaGraph.
-
- Args:
- index_name (str): The name of the index to check.
-
- Returns:
- bool: True if the index exists, False otherwise.
- """
- query = "SHOW VECTOR INDEXES"
- try:
- result = self.execute_query(query)
- return any(row.values()[0].as_string() == index_name for row in result)
- except Exception as e:
- logger.error(f"[Nebula] Failed to check index existence: {e}")
- return False
-
- @timed
- def _parse_value(self, value: Any) -> Any:
- """turn Nebula ValueWrapper to Python type"""
- from nebulagraph_python.value_wrapper import ValueWrapper
-
- if value is None or (hasattr(value, "is_null") and value.is_null()):
- return None
- try:
- prim = value.cast_primitive() if isinstance(value, ValueWrapper) else value
- except Exception as e:
- logger.warning(f"Error when decode Nebula ValueWrapper: {e}")
- prim = value.cast() if isinstance(value, ValueWrapper) else value
-
- if isinstance(prim, ValueWrapper):
- return self._parse_value(prim)
- if isinstance(prim, list):
- return [self._parse_value(v) for v in prim]
- if type(prim).__name__ == "NVector":
- return list(prim.values)
-
- return prim # already a Python primitive
-
- def _parse_node(self, props: dict[str, Any]) -> dict[str, Any]:
- parsed = {k: self._parse_value(v) for k, v in props.items()}
-
- for tf in ("created_at", "updated_at"):
- if tf in parsed and parsed[tf] is not None:
- parsed[tf] = _normalize_datetime(parsed[tf])
-
- node_id = parsed.pop("id")
- memory = parsed.pop("memory", "")
- parsed.pop("user_name", None)
- metadata = parsed
- metadata["type"] = metadata.pop("node_type")
-
- if self.dim_field in metadata:
- metadata["embedding"] = metadata.pop(self.dim_field)
-
- return {"id": node_id, "memory": memory, "metadata": metadata}
-
- @timed
- def _prepare_node_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """
- Ensure metadata has proper datetime fields and normalized types.
-
- - Fill `created_at` and `updated_at` if missing (in ISO 8601 format).
- - Convert embedding to list of float if present.
- """
- now = datetime.utcnow().isoformat()
- metadata["node_type"] = metadata.pop("type")
-
- # Fill timestamps if missing
- metadata.setdefault("created_at", now)
- metadata.setdefault("updated_at", now)
-
- # Normalize embedding type
- embedding = metadata.get("embedding")
- if embedding and isinstance(embedding, list):
- metadata.pop("embedding")
- metadata[self.dim_field] = _normalize([float(x) for x in embedding])
-
- return metadata
-
- @timed
- def _format_value(self, val: Any, key: str = "") -> str:
- from nebulagraph_python.py_data_types import NVector
-
- # None
- if val is None:
- return "NULL"
- # bool
- if isinstance(val, bool):
- return "true" if val else "false"
- # str
- if isinstance(val, str):
- return f'"{_escape_str(val)}"'
- # num
- elif isinstance(val, (int | float)):
- return str(val)
- # time
- elif isinstance(val, datetime):
- return f'datetime("{val.isoformat()}")'
- # list
- elif isinstance(val, list):
- if key == self.dim_field:
- dim = len(val)
- joined = ",".join(str(float(x)) for x in val)
- return f"VECTOR<{dim}, FLOAT>([{joined}])"
- else:
- return f"[{', '.join(self._format_value(v) for v in val)}]"
- # NVector
- elif isinstance(val, NVector):
- if key == self.dim_field:
- dim = len(val)
- joined = ",".join(str(float(x)) for x in val)
- return f"VECTOR<{dim}, FLOAT>([{joined}])"
- else:
- logger.warning("Invalid NVector")
- # dict
- if isinstance(val, dict):
- j = json.dumps(val, ensure_ascii=False, separators=(",", ":"))
- return f'"{_escape_str(j)}"'
- else:
- return f'"{_escape_str(str(val))}"'
-
- @timed
- def _metadata_filter(self, metadata: dict[str, Any]) -> dict[str, Any]:
- """
- Filter and validate metadata dictionary against the Memory node schema.
- - Removes keys not in schema.
- - Warns if required fields are missing.
- """
-
- dim_fields = {self.dim_field}
-
- allowed_fields = self.common_fields | dim_fields
-
- missing_fields = allowed_fields - metadata.keys()
- if missing_fields:
- logger.info(f"Metadata missing required fields: {sorted(missing_fields)}")
-
- filtered_metadata = {k: v for k, v in metadata.items() if k in allowed_fields}
-
- return filtered_metadata
-
- def _build_return_fields(self, include_embedding: bool = False) -> str:
- fields = set(self.base_fields)
- if include_embedding:
- fields.add(self.dim_field)
- return ", ".join(f"n.{f} AS {f}" for f in fields)
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 816dcd86d..31aab9407 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -443,7 +443,7 @@ def remove_oldest_memory(
)
user_name = user_name if user_name else self._get_config_value("user_name")
- # Use actual OFFSET logic, consistent with nebular.py
+ # Use actual OFFSET logic for deterministic pruning
# First find IDs to delete, then delete them
select_query = f"""
SELECT id FROM "{self.db_name}_graph"."Memory"
@@ -3531,7 +3531,7 @@ def get_neighbors_by_tag_ccl(
user_name = user_name if user_name else self._get_config_value("user_name")
- # Build query conditions; keep consistent with nebular.py
+ # Build query conditions shared with other graph backends
where_clauses = [
'n.status = "activated"',
'NOT (n.node_type = "reasoning")',
@@ -3584,7 +3584,7 @@ def get_neighbors_by_tag_ccl(
# Add overlap_count
result_fields.append("overlap_count agtype")
result_fields_str = ", ".join(result_fields)
- # Use Cypher query; keep consistent with nebular.py
+ # Use Cypher query to keep the graph query path aligned
query = f"""
SELECT * FROM (
SELECT * FROM cypher('{self.db_name}_graph', $$
diff --git a/src/memos/mem_os/client.py b/src/memos/mem_os/client.py
deleted file mode 100644
index f4a591e59..000000000
--- a/src/memos/mem_os/client.py
+++ /dev/null
@@ -1,5 +0,0 @@
-# TODO: @Li Ji
-
-
-class ClientMOS:
- pass
diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
deleted file mode 100644
index b2c74c384..000000000
--- a/src/memos/mem_os/product.py
+++ /dev/null
@@ -1,1610 +0,0 @@
-import asyncio
-import json
-import os
-import random
-import time
-
-from collections.abc import Generator
-from datetime import datetime
-from typing import Any, Literal
-
-from dotenv import load_dotenv
-from transformers import AutoTokenizer
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.context.context import ContextThread
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.core import MOSCore
-from memos.mem_os.utils.format_utils import (
- clean_json_response,
- convert_graph_to_tree_forworkmem,
- ensure_unique_tree_ids,
- filter_nodes_by_tree_ids,
- remove_embedding_recursive,
- sort_children_by_memory_type,
-)
-from memos.mem_os.utils.reference_utils import (
- prepare_reference_data,
- process_streaming_references_complete,
-)
-from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
-from memos.mem_scheduler.schemas.task_schemas import (
- ANSWER_TASK_LABEL,
- QUERY_TASK_LABEL,
-)
-from memos.mem_user.persistent_factory import PersistentUserManagerFactory
-from memos.mem_user.user_manager import UserRole
-from memos.memories.textual.item import (
- TextualMemoryItem,
-)
-from memos.templates.mos_prompts import (
- FURTHER_SUGGESTION_PROMPT,
- SUGGESTION_QUERY_PROMPT_EN,
- SUGGESTION_QUERY_PROMPT_ZH,
- get_memos_prompt,
-)
-from memos.types import MessageList
-from memos.utils import timed
-
-
-logger = get_logger(__name__)
-
-load_dotenv()
-
-CUBE_PATH = os.getenv("MOS_CUBE_PATH", "/tmp/data/")
-
-
-def _short_id(mem_id: str) -> str:
- return (mem_id or "").split("-")[0] if mem_id else ""
-
-
-def _format_mem_block(memories_all, max_items: int = 20, max_chars_each: int = 320) -> str:
- """
- Modify TextualMemoryItem Format:
- 1:abcd :: [P] text...
- 2:ef01 :: [O] text...
- sequence is [i:memId] i; [P]=PersonalMemory / [O]=OuterMemory
- """
- if not memories_all:
- return "(none)", "(none)"
-
- lines_o = []
- lines_p = []
- for idx, m in enumerate(memories_all[:max_items], 1):
- mid = _short_id(getattr(m, "id", "") or "")
- mtype = getattr(getattr(m, "metadata", {}), "memory_type", None) or getattr(
- m, "metadata", {}
- ).get("memory_type", "")
- tag = "O" if "Outer" in str(mtype) else "P"
- txt = (getattr(m, "memory", "") or "").replace("\n", " ").strip()
- if len(txt) > max_chars_each:
- txt = txt[: max_chars_each - 1] + "โฆ"
- mid = mid or f"mem_{idx}"
- if tag == "O":
- lines_o.append(f"[{idx}:{mid}] :: [{tag}] {txt}\n")
- elif tag == "P":
- lines_p.append(f"[{idx}:{mid}] :: [{tag}] {txt}")
- return "\n".join(lines_o), "\n".join(lines_p)
-
-
-class MOSProduct(MOSCore):
- """
- The MOSProduct class inherits from MOSCore and manages multiple users.
- Each user has their own configuration and cube access, but shares the same model instances.
- """
-
- def __init__(
- self,
- default_config: MOSConfig | None = None,
- max_user_instances: int = 1,
- default_cube_config: GeneralMemCubeConfig | None = None,
- online_bot=None,
- error_bot=None,
- ):
- """
- Initialize MOSProduct with an optional default configuration.
-
- Args:
- default_config (MOSConfig | None): Default configuration for new users
- max_user_instances (int): Maximum number of user instances to keep in memory
- default_cube_config (GeneralMemCubeConfig | None): Default cube configuration for loading cubes
- online_bot: DingDing online_bot function or None if disabled
- error_bot: DingDing error_bot function or None if disabled
- """
- # Initialize with a root config for shared resources
- if default_config is None:
- # Create a minimal config for root user
- root_config = MOSConfig(
- user_id="root",
- session_id="root_session",
- chat_model=default_config.chat_model if default_config else None,
- mem_reader=default_config.mem_reader if default_config else None,
- enable_mem_scheduler=default_config.enable_mem_scheduler
- if default_config
- else False,
- mem_scheduler=default_config.mem_scheduler if default_config else None,
- )
- else:
- root_config = default_config.model_copy(deep=True)
- root_config.user_id = "root"
- root_config.session_id = "root_session"
-
- # Create persistent user manager BEFORE calling parent constructor
- persistent_user_manager_client = PersistentUserManagerFactory.from_config(
- config_factory=root_config.user_manager
- )
-
- # Initialize parent MOSCore with root config and persistent user manager
- super().__init__(root_config, user_manager=persistent_user_manager_client)
-
- # Product-specific attributes
- self.default_config = default_config
- self.default_cube_config = default_cube_config
- self.max_user_instances = max_user_instances
- self.online_bot = online_bot
- self.error_bot = error_bot
-
- # User-specific data structures
- self.user_configs: dict[str, MOSConfig] = {}
- self.user_cube_access: dict[str, set[str]] = {} # user_id -> set of cube_ids
- self.user_chat_histories: dict[str, dict] = {}
-
- # Note: self.user_manager is now the persistent user manager from parent class
- # No need for separate global_user_manager as they are the same instance
-
- # Initialize tiktoken for streaming
- try:
- # Use gpt2 encoding which is more stable and widely compatible
- self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
- logger.info("tokenizer initialized successfully for streaming")
- except Exception as e:
- logger.warning(
- f"Failed to initialize tokenizer, will use character-based chunking: {e}"
- )
- self.tokenizer = None
-
- # Restore user instances from persistent storage
- self._restore_user_instances(default_cube_config=default_cube_config)
- logger.info(f"User instances restored successfully, now user is {self.mem_cubes.keys()}")
-
- def _restore_user_instances(
- self, default_cube_config: GeneralMemCubeConfig | None = None
- ) -> None:
- """Restore user instances from persistent storage after service restart.
-
- Args:
- default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None.
- """
- try:
- # Get all user configurations from persistent storage
- user_configs = self.user_manager.list_user_configs(self.max_user_instances)
-
- # Get the raw database records for sorting by updated_at
- session = self.user_manager._get_session()
- try:
- from memos.mem_user.persistent_user_manager import UserConfig
-
- db_configs = session.query(UserConfig).limit(self.max_user_instances).all()
- # Create a mapping of user_id to updated_at timestamp
- updated_at_map = {config.user_id: config.updated_at for config in db_configs}
-
- # Sort by updated_at timestamp (most recent first) and limit by max_instances
- sorted_configs = sorted(
- user_configs.items(), key=lambda x: updated_at_map.get(x[0], ""), reverse=True
- )[: self.max_user_instances]
- finally:
- session.close()
-
- for user_id, config in sorted_configs:
- if user_id != "root": # Skip root user
- try:
- # Store user config and cube access
- self.user_configs[user_id] = config
- self._load_user_cube_access(user_id)
-
- # Pre-load all cubes for this user with default config
- self._preload_user_cubes(user_id, default_cube_config)
-
- logger.info(
- f"Restored user configuration and pre-loaded cubes for {user_id}"
- )
-
- except Exception as e:
- logger.error(f"Failed to restore user configuration for {user_id}: {e}")
-
- except Exception as e:
- logger.error(f"Error during user instance restoration: {e}")
-
- def _initialize_cube_from_default_config(
- self, cube_id: str, user_id: str, default_config: GeneralMemCubeConfig
- ) -> GeneralMemCube | None:
- """
- Initialize a cube from default configuration when cube path doesn't exist.
-
- Args:
- cube_id (str): The cube ID to initialize.
- user_id (str): The user ID for the cube.
- default_config (GeneralMemCubeConfig): The default configuration to use.
- """
- cube_config = default_config.model_copy(deep=True)
- # Safely modify the graph_db user_name if it exists
- if cube_config.text_mem.config.graph_db.config:
- cube_config.text_mem.config.graph_db.config.user_name = (
- f"memos{user_id.replace('-', '')}"
- )
- mem_cube = GeneralMemCube(config=cube_config)
- return mem_cube
-
- def _preload_user_cubes(
- self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None
- ) -> None:
- """Pre-load all cubes for a user into memory.
-
- Args:
- user_id (str): The user ID to pre-load cubes for.
- default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None.
- """
- try:
- # Get user's accessible cubes from persistent storage
- accessible_cubes = self.user_manager.get_user_cubes(user_id)
-
- for cube in accessible_cubes:
- if cube.cube_id not in self.mem_cubes:
- try:
- if cube.cube_path and os.path.exists(cube.cube_path):
- # Pre-load cube with all memory types and default config
- self.register_mem_cube(
- cube.cube_path,
- cube.cube_id,
- user_id,
- memory_types=["act_mem"]
- if self.config.enable_activation_memory
- else [],
- default_config=default_cube_config,
- )
- logger.info(f"Pre-loaded cube {cube.cube_id} for user {user_id}")
- else:
- logger.warning(
- f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, skipping pre-load"
- )
- except Exception as e:
- logger.error(
- f"Failed to pre-load cube {cube.cube_id} for user {user_id}: {e}",
- exc_info=True,
- )
-
- except Exception as e:
- logger.error(f"Error pre-loading cubes for user {user_id}: {e}", exc_info=True)
-
- @timed
- def _load_user_cubes(
- self, user_id: str, default_cube_config: GeneralMemCubeConfig | None = None
- ) -> None:
- """Load all cubes for a user into memory.
-
- Args:
- user_id (str): The user ID to load cubes for.
- default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None.
- """
- # Get user's accessible cubes from persistent storage
- accessible_cubes = self.user_manager.get_user_cubes(user_id)
-
- for cube in accessible_cubes[:1]:
- if cube.cube_id not in self.mem_cubes:
- try:
- if cube.cube_path and os.path.exists(cube.cube_path):
- # Use MOSCore's register_mem_cube method directly with default config
- # Only load act_mem since text_mem is stored in database
- self.register_mem_cube(
- cube.cube_path,
- cube.cube_id,
- user_id,
- memory_types=["act_mem"],
- default_config=default_cube_config,
- )
- else:
- logger.warning(
- f"Cube path {cube.cube_path} does not exist for cube {cube.cube_id}, now init by default config"
- )
- cube_obj = self._initialize_cube_from_default_config(
- cube_id=cube.cube_id,
- user_id=user_id,
- default_config=default_cube_config,
- )
- if cube_obj:
- self.register_mem_cube(
- cube_obj,
- cube.cube_id,
- user_id,
- memory_types=[],
- )
- else:
- raise ValueError(
- f"Failed to initialize default cube {cube.cube_id} for user {user_id}"
- )
- except Exception as e:
- logger.error(f"Failed to load cube {cube.cube_id} for user {user_id}: {e}")
- logger.info(f"load user {user_id} cubes successfully")
-
- def _ensure_user_instance(self, user_id: str, max_instances: int | None = None) -> None:
- """
- Ensure user configuration exists, creating it if necessary.
-
- Args:
- user_id (str): The user ID
- max_instances (int): Maximum instances to keep in memory (overrides class default)
- """
- if user_id in self.user_configs:
- return
-
- # Try to get config from persistent storage first
- stored_config = self.user_manager.get_user_config(user_id)
- if stored_config:
- self.user_configs[user_id] = stored_config
- self._load_user_cube_access(user_id)
- else:
- # Use default config
- if not self.default_config:
- raise ValueError(f"No configuration available for user {user_id}")
- user_config = self.default_config.model_copy(deep=True)
- user_config.user_id = user_id
- user_config.session_id = f"{user_id}_session"
- self.user_configs[user_id] = user_config
- self._load_user_cube_access(user_id)
-
- # Apply LRU eviction if needed
- max_instances = max_instances or self.max_user_instances
- if len(self.user_configs) > max_instances:
- # Remove least recently used instance (excluding root)
- user_ids = [uid for uid in self.user_configs if uid != "root"]
- if user_ids:
- oldest_user_id = user_ids[0]
- del self.user_configs[oldest_user_id]
- if oldest_user_id in self.user_cube_access:
- del self.user_cube_access[oldest_user_id]
- logger.info(f"Removed least recently used user configuration: {oldest_user_id}")
-
- def _load_user_cube_access(self, user_id: str) -> None:
- """Load user's cube access permissions."""
- try:
- # Get user's accessible cubes from persistent storage
- accessible_cubes = self.user_manager.get_user_cube_access(user_id)
- self.user_cube_access[user_id] = set(accessible_cubes)
- except Exception as e:
- logger.warning(f"Failed to load cube access for user {user_id}: {e}")
- self.user_cube_access[user_id] = set()
-
- def _get_user_config(self, user_id: str) -> MOSConfig:
- """Get user configuration."""
- if user_id not in self.user_configs:
- self._ensure_user_instance(user_id)
- return self.user_configs[user_id]
-
- def _validate_user_cube_access(self, user_id: str, cube_id: str) -> None:
- """Validate user has access to the cube."""
- if user_id not in self.user_cube_access:
- self._load_user_cube_access(user_id)
-
- if cube_id not in self.user_cube_access.get(user_id, set()):
- raise ValueError(f"User '{user_id}' does not have access to cube '{cube_id}'")
-
- def _validate_user_access(self, user_id: str, cube_id: str | None = None) -> None:
- """Validate user access using MOSCore's built-in validation."""
- # Use MOSCore's built-in user validation
- if cube_id:
- self._validate_cube_access(user_id, cube_id)
- else:
- self._validate_user_exists(user_id)
-
- def _create_user_config(self, user_id: str, config: MOSConfig) -> MOSConfig:
- """Create a new user configuration."""
- # Create a copy of config with the specific user_id
- user_config = config.model_copy(deep=True)
- user_config.user_id = user_id
- user_config.session_id = f"{user_id}_session"
-
- # Save configuration to persistent storage
- self.user_manager.save_user_config(user_id, user_config)
-
- return user_config
-
- def _get_or_create_user_config(
- self, user_id: str, config: MOSConfig | None = None
- ) -> MOSConfig:
- """Get existing user config or create a new one."""
- if user_id in self.user_configs:
- return self.user_configs[user_id]
-
- # Try to get config from persistent storage first
- stored_config = self.user_manager.get_user_config(user_id)
- if stored_config:
- return self._create_user_config(user_id, stored_config)
-
- # Use provided config or default config
- user_config = config or self.default_config
- if not user_config:
- raise ValueError(f"No configuration provided for user {user_id}")
-
- return self._create_user_config(user_id, user_config)
-
- def _build_system_prompt(
- self,
- memories_all: list[TextualMemoryItem],
- base_prompt: str | None = None,
- tone: str = "friendly",
- verbosity: str = "mid",
- ) -> str:
- """
- Build custom system prompt for the user with memory references.
-
- Args:
- user_id (str): The user ID.
- memories (list[TextualMemoryItem]): The memories to build the system prompt.
-
- Returns:
- str: The custom system prompt.
- """
- # Build base prompt
- # Add memory context if available
- now = datetime.now()
- formatted_date = now.strftime("%Y-%m-%d (%A)")
- sys_body = get_memos_prompt(
- date=formatted_date, tone=tone, verbosity=verbosity, mode="base"
- )
- mem_block_o, mem_block_p = _format_mem_block(memories_all)
- mem_block = mem_block_o + "\n" + mem_block_p
- prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
- return (
- prefix
- + sys_body
- + "\n\n# Memories\n## PersonalMemory & OuterMemory (ordered)\n"
- + mem_block
- )
-
- def _build_base_system_prompt(
- self,
- base_prompt: str | None = None,
- tone: str = "friendly",
- verbosity: str = "mid",
- mode: str = "enhance",
- ) -> str:
- """
- Build base system prompt without memory references.
- """
- now = datetime.now()
- formatted_date = now.strftime("%Y-%m-%d (%A)")
- sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode)
- prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
- return prefix + sys_body
-
- def _build_memory_context(
- self,
- memories_all: list[TextualMemoryItem],
- mode: str = "enhance",
- ) -> str:
- """
- Build memory context to be included in user message.
- """
- if not memories_all:
- return ""
-
- mem_block_o, mem_block_p = _format_mem_block(memories_all)
-
- if mode == "enhance":
- return (
- "# Memories\n## PersonalMemory (ordered)\n"
- + mem_block_p
- + "\n## OuterMemory (ordered)\n"
- + mem_block_o
- + "\n\n"
- )
- else:
- mem_block = mem_block_o + "\n" + mem_block_p
- return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n"
-
- def _build_enhance_system_prompt(
- self,
- user_id: str,
- memories_all: list[TextualMemoryItem],
- tone: str = "friendly",
- verbosity: str = "mid",
- ) -> str:
- """
- Build enhance prompt for the user with memory references.
- [DEPRECATED] Use _build_base_system_prompt and _build_memory_context instead.
- """
- now = datetime.now()
- formatted_date = now.strftime("%Y-%m-%d (%A)")
- sys_body = get_memos_prompt(
- date=formatted_date, tone=tone, verbosity=verbosity, mode="enhance"
- )
- mem_block_o, mem_block_p = _format_mem_block(memories_all)
- return (
- sys_body
- + "\n\n# Memories\n## PersonalMemory (ordered)\n"
- + mem_block_p
- + "\n## OuterMemory (ordered)\n"
- + mem_block_o
- )
-
- def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]:
- """
- Extract reference information from the response and return clean text.
-
- Args:
- response (str): The complete response text.
-
- Returns:
- tuple[str, list[dict]]: A tuple containing:
- - clean_text: Text with reference markers removed
- - references: List of reference information
- """
- import re
-
- try:
- references = []
- # Pattern to match [refid:memoriesID]
- pattern = r"\[(\d+):([^\]]+)\]"
-
- matches = re.findall(pattern, response)
- for ref_number, memory_id in matches:
- references.append({"memory_id": memory_id, "reference_number": int(ref_number)})
-
- # Remove all reference markers from the text to get clean text
- clean_text = re.sub(pattern, "", response)
-
- # Clean up any extra whitespace that might be left after removing markers
- clean_text = re.sub(r"\s+", " ", clean_text).strip()
-
- return clean_text, references
- except Exception as e:
- logger.error(f"Error extracting references from response: {e}", exc_info=True)
- return response, []
-
- def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict:
- """
- get struct message from chat-history
- # TODO: @xcy make this more general
- """
- system_content = ""
- memory_content = ""
- chat_history = []
-
- for item in chat_data:
- role = item.get("role")
- content = item.get("content", "")
- if role == "system":
- parts = content.split("# Memories", 1)
- system_content = parts[0].strip()
- if len(parts) > 1:
- memory_content = "# Memories" + parts[1].strip()
- elif role in ("user", "assistant"):
- chat_history.append({"role": role, "content": content})
-
- if chat_history and chat_history[-1]["role"] == "assistant":
- if len(chat_history) >= 2 and chat_history[-2]["role"] == "user":
- chat_history = chat_history[:-2]
- else:
- chat_history = chat_history[:-1]
-
- return {"system": system_content, "memory": memory_content, "chat_history": chat_history}
-
- def _chunk_response_with_tiktoken(
- self, response: str, chunk_size: int = 5
- ) -> Generator[str, None, None]:
- """
- Chunk response using tiktoken for proper token-based streaming.
-
- Args:
- response (str): The response text to chunk.
- chunk_size (int): Number of tokens per chunk.
-
- Yields:
- str: Chunked text pieces.
- """
- if self.tokenizer:
- # Use tiktoken for proper token-based chunking
- tokens = self.tokenizer.encode(response)
-
- for i in range(0, len(tokens), chunk_size):
- token_chunk = tokens[i : i + chunk_size]
- chunk_text = self.tokenizer.decode(token_chunk)
- yield chunk_text
- else:
- # Fallback to character-based chunking
- char_chunk_size = chunk_size * 4 # Approximate character to token ratio
- for i in range(0, len(response), char_chunk_size):
- yield response[i : i + char_chunk_size]
-
- def _send_message_to_scheduler(
- self,
- user_id: str,
- mem_cube_id: str,
- query: str,
- label: str,
- ):
- """
- Send message to scheduler.
- args:
- user_id: str,
- mem_cube_id: str,
- query: str,
- """
-
- if self.enable_mem_scheduler and (self.mem_scheduler is not None):
- message_item = ScheduleMessageItem(
- user_id=user_id,
- mem_cube_id=mem_cube_id,
- label=label,
- content=query,
- timestamp=datetime.utcnow(),
- )
- self.mem_scheduler.submit_messages(messages=[message_item])
-
- async def _post_chat_processing(
- self,
- user_id: str,
- cube_id: str,
- query: str,
- full_response: str,
- system_prompt: str,
- time_start: float,
- time_end: float,
- speed_improvement: float,
- current_messages: list,
- ) -> None:
- """
- Asynchronous processing of logs, notifications and memory additions
- """
- try:
- logger.info(
- f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}"
- )
- logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}")
-
- clean_response, extracted_references = self._extract_references_from_response(
- full_response
- )
- struct_message = self._extract_struct_data_from_history(current_messages)
- logger.info(f"Extracted {len(extracted_references)} references from response")
-
- # Send chat report notifications asynchronously
- if self.online_bot:
- logger.info("Online Bot Open!")
- try:
- from memos.memos_tools.notification_utils import (
- send_online_bot_notification_async,
- )
-
- # Prepare notification data
- chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id}
- chat_data.update(
- {
- "memory": struct_message["memory"],
- "chat_history": struct_message["chat_history"],
- "full_response": full_response,
- }
- )
-
- system_data = {
- "references": extracted_references,
- "time_start": time_start,
- "time_end": time_end,
- "speed_improvement": speed_improvement,
- }
-
- emoji_config = {"chat": "๐ฌ", "system_info": "๐"}
-
- await send_online_bot_notification_async(
- online_bot=self.online_bot,
- header_name="MemOS Chat Report",
- sub_title_name="chat_with_references",
- title_color="#00956D",
- other_data1=chat_data,
- other_data2=system_data,
- emoji=emoji_config,
- )
- except Exception as e:
- logger.warning(f"Failed to send chat notification (async): {e}")
-
- self._send_message_to_scheduler(
- user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_TASK_LABEL
- )
-
- self.add(
- user_id=user_id,
- messages=[
- {
- "role": "user",
- "content": query,
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- {
- "role": "assistant",
- "content": clean_response, # Store clean text without reference markers
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- ],
- mem_cube_id=cube_id,
- )
-
- logger.info(f"Post-chat processing completed for user {user_id}")
-
- except Exception as e:
- logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True)
-
- def _start_post_chat_processing(
- self,
- user_id: str,
- cube_id: str,
- query: str,
- full_response: str,
- system_prompt: str,
- time_start: float,
- time_end: float,
- speed_improvement: float,
- current_messages: list,
- ) -> None:
- """
- Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments
- """
- logger.info("Start post_chat_processing...")
-
- def run_async_in_thread():
- """Running asynchronous tasks in a new thread"""
- try:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- self._post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- )
- )
- finally:
- loop.close()
- except Exception as e:
- logger.error(
- f"Error in thread-based post-chat processing for user {user_id}: {e}",
- exc_info=True,
- )
-
- try:
- # Try to get the current event loop
- asyncio.get_running_loop()
- # Create task and store reference to prevent garbage collection
- task = asyncio.create_task(
- self._post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- )
- )
- # Add exception handling for the background task
- task.add_done_callback(
- lambda t: (
- logger.error(
- f"Error in background post-chat processing for user {user_id}: {t.exception()}",
- exc_info=True,
- )
- if t.exception()
- else None
- )
- )
- except RuntimeError:
- # No event loop, run in a new thread with context propagation
- thread = ContextThread(
- target=run_async_in_thread,
- name=f"PostChatProcessing-{user_id}",
- # Set as a daemon thread to avoid blocking program exit
- daemon=True,
- )
- thread.start()
-
- def _filter_memories_by_threshold(
- self,
- memories: list[TextualMemoryItem],
- threshold: float = 0.30,
- min_num: int = 3,
- memory_type: Literal["OuterMemory"] = "OuterMemory",
- ) -> list[TextualMemoryItem]:
- """
- Filter memories by threshold and type, at least min_num memories for Non-OuterMemory.
- Args:
- memories: list[TextualMemoryItem],
- threshold: float,
- min_num: int,
- memory_type: Literal["OuterMemory"],
- Returns:
- list[TextualMemoryItem]
- """
- sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True)
- filtered_person = [m for m in memories if m.metadata.memory_type != memory_type]
- filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type]
- filtered = []
- per_memory_count = 0
- for m in sorted_memories:
- if m.metadata.relativity >= threshold:
- if m.metadata.memory_type != memory_type:
- per_memory_count += 1
- filtered.append(m)
- if len(filtered) < min_num:
- filtered = filtered_person[:min_num] + filtered_outer[:min_num]
- else:
- if per_memory_count < min_num:
- filtered += filtered_person[per_memory_count:min_num]
- filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True)
- return filtered_memory
-
- def register_mem_cube(
- self,
- mem_cube_name_or_path_or_object: str | GeneralMemCube,
- mem_cube_id: str | None = None,
- user_id: str | None = None,
- memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None,
- default_config: GeneralMemCubeConfig | None = None,
- ) -> None:
- """
- Register a MemCube with the MOS.
-
- Args:
- mem_cube_name_or_path_or_object (str | GeneralMemCube): The name, path, or GeneralMemCube object to register.
- mem_cube_id (str, optional): The identifier for the MemCube. If not provided, a default ID is used.
- user_id (str, optional): The user ID to register the cube for.
- memory_types (list[str], optional): List of memory types to load.
- If None, loads all available memory types.
- Options: ["text_mem", "act_mem", "para_mem"]
- default_config (GeneralMemCubeConfig, optional): Default configuration for the cube.
- """
- # Handle different input types
- if isinstance(mem_cube_name_or_path_or_object, GeneralMemCube):
- # Direct GeneralMemCube object provided
- mem_cube = mem_cube_name_or_path_or_object
- if mem_cube_id is None:
- mem_cube_id = f"cube_{id(mem_cube)}" # Generate a unique ID
- else:
- # String path provided
- mem_cube_name_or_path = mem_cube_name_or_path_or_object
- if mem_cube_id is None:
- mem_cube_id = mem_cube_name_or_path
-
- if mem_cube_id in self.mem_cubes:
- logger.info(f"MemCube with ID {mem_cube_id} already in MOS, skip install.")
- return
-
- # Create MemCube from path
- time_start = time.time()
- if os.path.exists(mem_cube_name_or_path):
- mem_cube = GeneralMemCube.init_from_dir(
- mem_cube_name_or_path, memory_types, default_config
- )
- logger.info(
- f"time register_mem_cube: init_from_dir time is: {time.time() - time_start}"
- )
- else:
- logger.warning(
- f"MemCube {mem_cube_name_or_path} does not exist, try to init from remote repo."
- )
- mem_cube = GeneralMemCube.init_from_remote_repo(
- mem_cube_name_or_path, memory_types=memory_types, default_config=default_config
- )
-
- # Register the MemCube
- logger.info(
- f"Registering MemCube {mem_cube_id} with cube config {mem_cube.config.model_dump(mode='json')}"
- )
- time_start = time.time()
- self.mem_cubes[mem_cube_id] = mem_cube
- time_end = time.time()
- logger.info(f"time register_mem_cube: add mem_cube time is: {time_end - time_start}")
-
- def user_register(
- self,
- user_id: str,
- user_name: str | None = None,
- config: MOSConfig | None = None,
- interests: str | None = None,
- default_mem_cube: GeneralMemCube | None = None,
- default_cube_config: GeneralMemCubeConfig | None = None,
- mem_cube_id: str | None = None,
- ) -> dict[str, str]:
- """Register a new user with configuration and default cube.
-
- Args:
- user_id (str): The user ID for registration.
- user_name (str): The user name for registration.
- config (MOSConfig | None, optional): User-specific configuration. Defaults to None.
- interests (str | None, optional): User interests as string. Defaults to None.
- default_mem_cube (GeneralMemCube | None, optional): Default memory cube. Defaults to None.
- default_cube_config (GeneralMemCubeConfig | None, optional): Default cube configuration. Defaults to None.
-
- Returns:
- dict[str, str]: Registration result with status and message.
- """
- try:
- # Use provided config or default config
- user_config = config or self.default_config
- if not user_config:
- return {
- "status": "error",
- "message": "No configuration provided for user registration",
- }
- if not user_name:
- user_name = user_id
-
- # Create user with configuration using persistent user manager
- self.user_manager.create_user_with_config(user_id, user_config, UserRole.USER, user_id)
-
- # Create user configuration
- user_config = self._create_user_config(user_id, user_config)
-
- # Create a default cube for the user using MOSCore's methods
- default_cube_name = f"{user_name}_{user_id}_default_cube"
- mem_cube_name_or_path = os.path.join(CUBE_PATH, default_cube_name)
- default_cube_id = self.create_cube_for_user(
- cube_name=default_cube_name,
- owner_id=user_id,
- cube_path=mem_cube_name_or_path,
- cube_id=mem_cube_id,
- )
- time_start = time.time()
- if default_mem_cube:
- try:
- default_mem_cube.dump(mem_cube_name_or_path, memory_types=[])
- except Exception as e:
- logger.error(f"Failed to dump default cube: {e}")
- time_end = time.time()
- logger.info(f"time user_register: dump default cube time is: {time_end - time_start}")
- # Register the default cube with MOS
- self.register_mem_cube(
- mem_cube_name_or_path_or_object=default_mem_cube,
- mem_cube_id=default_cube_id,
- user_id=user_id,
- memory_types=["act_mem"] if self.config.enable_activation_memory else [],
- default_config=default_cube_config, # use default cube config
- )
-
- # Add interests to the default cube if provided
- if interests:
- self.add(memory_content=interests, mem_cube_id=default_cube_id, user_id=user_id)
-
- return {
- "status": "success",
- "message": f"User {user_name} registered successfully with default cube {default_cube_id}",
- "user_id": user_id,
- "default_cube_id": default_cube_id,
- }
-
- except Exception as e:
- return {"status": "error", "message": f"Failed to register user: {e!s}"}
-
- def _get_further_suggestion(self, message: MessageList | None = None) -> list[str]:
- """Get further suggestion prompt."""
- try:
- dialogue_info = "\n".join([f"{msg['role']}: {msg['content']}" for msg in message[-2:]])
- further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info)
- message_list = [{"role": "system", "content": further_suggestion_prompt}]
- response = self.chat_llm.generate(message_list)
- clean_response = clean_json_response(response)
- response_json = json.loads(clean_response)
- return response_json["query"]
- except Exception as e:
- logger.error(f"Error getting further suggestion: {e}", exc_info=True)
- return []
-
- def get_suggestion_query(
- self, user_id: str, language: str = "zh", message: MessageList | None = None
- ) -> list[str]:
- """Get suggestion query from LLM.
- Args:
- user_id (str): User ID.
- language (str): Language for suggestions ("zh" or "en").
-
- Returns:
- list[str]: The suggestion query list.
- """
- if message:
- further_suggestion = self._get_further_suggestion(message)
- return further_suggestion
- if language == "zh":
- suggestion_prompt = SUGGESTION_QUERY_PROMPT_ZH
- else: # English
- suggestion_prompt = SUGGESTION_QUERY_PROMPT_EN
- text_mem_result = super().search("my recently memories", user_id=user_id, top_k=3)[
- "text_mem"
- ]
- if text_mem_result:
- memories = "\n".join([m.memory[:200] for m in text_mem_result[0]["memories"]])
- else:
- memories = ""
- message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}]
- response = self.chat_llm.generate(message_list)
- clean_response = clean_json_response(response)
- response_json = json.loads(clean_response)
- return response_json["query"]
-
- def chat(
- self,
- query: str,
- user_id: str,
- cube_id: str | None = None,
- history: MessageList | None = None,
- base_prompt: str | None = None,
- internet_search: bool = False,
- moscube: bool = False,
- top_k: int = 10,
- threshold: float = 0.5,
- session_id: str | None = None,
- ) -> str:
- """
- Chat with LLM with memory references and complete response.
- """
- self._load_user_cubes(user_id, self.default_cube_config)
- time_start = time.time()
- memories_result = super().search(
- query,
- user_id,
- install_cube_ids=[cube_id] if cube_id else None,
- top_k=top_k,
- mode="fine",
- internet_search=internet_search,
- moscube=moscube,
- session_id=session_id,
- )["text_mem"]
-
- memories_list = []
- if memories_result:
- memories_list = memories_result[0]["memories"]
- memories_list = self._filter_memories_by_threshold(memories_list, threshold)
- new_memories_list = []
- for m in memories_list:
- m.metadata.embedding = []
- new_memories_list.append(m)
- memories_list = new_memories_list
-
- system_prompt = super()._build_system_prompt(memories_list, base_prompt)
- if history is not None:
- # Use the provided history (even if it's empty)
- history_info = history[-20:]
- else:
- # Fall back to internal chat_history
- if user_id not in self.chat_history_manager:
- self._register_chat_history(user_id, session_id)
- history_info = self.chat_history_manager[user_id].chat_history[-20:]
- current_messages = [
- {"role": "system", "content": system_prompt},
- *history_info,
- {"role": "user", "content": query},
- ]
- logger.info("Start to get final answer...")
- response = self.chat_llm.generate(current_messages)
- time_end = time.time()
- self._start_post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=0.0,
- current_messages=current_messages,
- )
- return response, memories_list
-
- def chat_with_references(
- self,
- query: str,
- user_id: str,
- cube_id: str | None = None,
- history: MessageList | None = None,
- top_k: int = 20,
- internet_search: bool = False,
- moscube: bool = False,
- session_id: str | None = None,
- ) -> Generator[str, None, None]:
- """
- Chat with LLM with memory references and streaming output.
-
- Args:
- query (str): Query string.
- user_id (str): User ID.
- cube_id (str, optional): Custom cube ID for user.
- history (MessageList, optional): Chat history.
-
- Returns:
- Generator[str, None, None]: The response string generator with reference processing.
- """
-
- self._load_user_cubes(user_id, self.default_cube_config)
- time_start = time.time()
- memories_list = []
- yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n"
- memories_result = super().search(
- query,
- user_id,
- install_cube_ids=[cube_id] if cube_id else None,
- top_k=top_k,
- mode="fine",
- internet_search=internet_search,
- moscube=moscube,
- session_id=session_id,
- )["text_mem"]
-
- yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
- search_time_end = time.time()
- logger.info(
- f"time chat: search text_mem time user_id: {user_id} time is: {search_time_end - time_start}"
- )
- self._send_message_to_scheduler(
- user_id=user_id, mem_cube_id=cube_id, query=query, label=QUERY_TASK_LABEL
- )
- if memories_result:
- memories_list = memories_result[0]["memories"]
- memories_list = self._filter_memories_by_threshold(memories_list)
-
- reference = prepare_reference_data(memories_list)
- yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
- # Build custom system prompt with relevant memories)
- system_prompt = self._build_enhance_system_prompt(user_id, memories_list)
- # Get chat history
- if user_id not in self.chat_history_manager:
- self._register_chat_history(user_id, session_id)
-
- chat_history = self.chat_history_manager[user_id]
- if history is not None:
- chat_history.chat_history = history[-20:]
- current_messages = [
- {"role": "system", "content": system_prompt},
- *chat_history.chat_history,
- {"role": "user", "content": query},
- ]
- logger.info(
- f"user_id: {user_id}, cube_id: {cube_id}, current_system_prompt: {system_prompt}"
- )
- yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n"
- # Generate response with custom prompt
- past_key_values = None
- response_stream = None
- if self.config.enable_activation_memory:
- # Handle activation memory (copy MOSCore logic)
- for mem_cube_id, mem_cube in self.mem_cubes.items():
- if mem_cube.act_mem and mem_cube_id == cube_id:
- kv_cache = next(iter(mem_cube.act_mem.get_all()), None)
- past_key_values = (
- kv_cache.memory if (kv_cache and hasattr(kv_cache, "memory")) else None
- )
- if past_key_values is not None:
- logger.info("past_key_values is not None will apply to chat")
- else:
- logger.info("past_key_values is None will not apply to chat")
- break
- if self.config.chat_model.backend == "huggingface":
- response_stream = self.chat_llm.generate_stream(
- current_messages, past_key_values=past_key_values
- )
- elif self.config.chat_model.backend == "vllm":
- response_stream = self.chat_llm.generate_stream(current_messages)
- else:
- if self.config.chat_model.backend in ["huggingface", "vllm", "openai"]:
- response_stream = self.chat_llm.generate_stream(current_messages)
- else:
- response_stream = self.chat_llm.generate(current_messages)
-
- time_end = time.time()
- chat_time_end = time.time()
- logger.info(
- f"time chat: chat time user_id: {user_id} time is: {chat_time_end - search_time_end}"
- )
- # Simulate streaming output with proper reference handling using tiktoken
-
- # Initialize buffer for streaming
- buffer = ""
- full_response = ""
- token_count = 0
- # Use tiktoken for proper token-based chunking
- if self.config.chat_model.backend not in ["huggingface", "vllm", "openai"]:
- # For non-huggingface backends, we need to collect the full response first
- full_response_text = ""
- for chunk in response_stream:
- if chunk in ["", ""]:
- continue
- full_response_text += chunk
- response_stream = self._chunk_response_with_tiktoken(full_response_text, chunk_size=5)
- for chunk in response_stream:
- if chunk in ["", ""]:
- continue
- token_count += 1
- buffer += chunk
- full_response += chunk
-
- # Process buffer to ensure complete reference tags
- processed_chunk, remaining_buffer = process_streaming_references_complete(buffer)
-
- if processed_chunk:
- chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
- yield chunk_data
- buffer = remaining_buffer
-
- # Process any remaining buffer
- if buffer:
- processed_chunk, remaining_buffer = process_streaming_references_complete(buffer)
- if processed_chunk:
- chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
- yield chunk_data
-
- # set kvcache improve speed
- speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1)
- total_time = round(float(time_end - time_start), 1)
-
- yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': f'{speed_improvement}%'}})}\n\n"
- # get further suggestion
- current_messages.append({"role": "assistant", "content": full_response})
- further_suggestion = self._get_further_suggestion(current_messages)
- logger.info(f"further_suggestion: {further_suggestion}")
- yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n"
- yield f"data: {json.dumps({'type': 'end'})}\n\n"
-
- # Asynchronous processing of logs, notifications and memory additions
- self._start_post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- )
-
- def get_all(
- self,
- user_id: str,
- memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"],
- mem_cube_ids: list[str] | None = None,
- ) -> list[dict[str, Any]]:
- """Get all memory items for a user.
-
- Args:
- user_id (str): The ID of the user.
- cube_id (str | None, optional): The ID of the cube. Defaults to None.
- memory_type (Literal["text_mem", "act_mem", "param_mem"]): The type of memory to get.
-
- Returns:
- list[dict[str, Any]]: A list of memory items with cube_id and memories structure.
- """
-
- # Load user cubes if not already loaded
- self._load_user_cubes(user_id, self.default_cube_config)
- time_start = time.time()
- memory_list = super().get_all(
- mem_cube_id=mem_cube_ids[0] if mem_cube_ids else None, user_id=user_id
- )[memory_type]
- get_all_time_end = time.time()
- logger.info(
- f"time get_all: get_all time user_id: {user_id} time is: {get_all_time_end - time_start}"
- )
- reformat_memory_list = []
- if memory_type == "text_mem":
- for memory in memory_list:
- memories = remove_embedding_recursive(memory["memories"])
- custom_type_ratios = {
- "WorkingMemory": 0.20,
- "LongTermMemory": 0.40,
- "UserMemory": 0.40,
- }
- tree_result, node_type_count = convert_graph_to_tree_forworkmem(
- memories, target_node_count=200, type_ratios=custom_type_ratios
- )
- # Ensure all node IDs are unique in the tree structure
- tree_result = ensure_unique_tree_ids(tree_result)
- memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
- children = tree_result["children"]
- children_sort = sort_children_by_memory_type(children)
- tree_result["children"] = children_sort
- memories_filtered["tree_structure"] = tree_result
- reformat_memory_list.append(
- {
- "cube_id": memory["cube_id"],
- "memories": [memories_filtered],
- "memory_statistics": node_type_count,
- }
- )
- elif memory_type == "act_mem":
- memories_list = []
- act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all()
- if act_mem_params:
- memories_data = act_mem_params[0].model_dump()
- records = memories_data.get("records", [])
- for record in records["text_memories"]:
- memories_list.append(
- {
- "id": memories_data["id"],
- "text": record,
- "create_time": records["timestamp"],
- "size": random.randint(1, 20),
- "modify_times": 1,
- }
- )
- reformat_memory_list.append(
- {
- "cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0],
- "memories": memories_list,
- }
- )
- elif memory_type == "para_mem":
- act_mem_params = self.mem_cubes[mem_cube_ids[0]].act_mem.get_all()
- logger.info(f"act_mem_params: {act_mem_params}")
- reformat_memory_list.append(
- {
- "cube_id": "xxxxxxxxxxxxxxxx" if not mem_cube_ids else mem_cube_ids[0],
- "memories": act_mem_params[0].model_dump(),
- }
- )
- make_format_time_end = time.time()
- logger.info(
- f"time get_all: make_format time user_id: {user_id} time is: {make_format_time_end - get_all_time_end}"
- )
- return reformat_memory_list
-
- def _get_subgraph(
- self, query: str, mem_cube_id: str, user_id: str | None = None, top_k: int = 5
- ) -> list[dict[str, Any]]:
- result = {"para_mem": [], "act_mem": [], "text_mem": []}
- if self.config.enable_textual_memory and self.mem_cubes[mem_cube_id].text_mem:
- result["text_mem"].append(
- {
- "cube_id": mem_cube_id,
- "memories": self.mem_cubes[mem_cube_id].text_mem.get_relevant_subgraph(
- query, top_k=top_k
- ),
- }
- )
- return result
-
- def get_subgraph(
- self,
- user_id: str,
- query: str,
- mem_cube_ids: list[str] | None = None,
- top_k: int = 20,
- ) -> list[dict[str, Any]]:
- """Get all memory items for a user.
-
- Args:
- user_id (str): The ID of the user.
- cube_id (str | None, optional): The ID of the cube. Defaults to None.
- mem_cube_ids (list[str], optional): The IDs of the cubes. Defaults to None.
-
- Returns:
- list[dict[str, Any]]: A list of memory items with cube_id and memories structure.
- """
-
- # Load user cubes if not already loaded
- self._load_user_cubes(user_id, self.default_cube_config)
- memory_list = self._get_subgraph(
- query=query, mem_cube_id=mem_cube_ids[0], user_id=user_id, top_k=top_k
- )["text_mem"]
- reformat_memory_list = []
- for memory in memory_list:
- memories = remove_embedding_recursive(memory["memories"])
- custom_type_ratios = {"WorkingMemory": 0.20, "LongTermMemory": 0.40, "UserMemory": 0.4}
- tree_result, node_type_count = convert_graph_to_tree_forworkmem(
- memories, target_node_count=150, type_ratios=custom_type_ratios
- )
- # Ensure all node IDs are unique in the tree structure
- tree_result = ensure_unique_tree_ids(tree_result)
- memories_filtered = filter_nodes_by_tree_ids(tree_result, memories)
- children = tree_result["children"]
- children_sort = sort_children_by_memory_type(children)
- tree_result["children"] = children_sort
- memories_filtered["tree_structure"] = tree_result
- reformat_memory_list.append(
- {
- "cube_id": memory["cube_id"],
- "memories": [memories_filtered],
- "memory_statistics": node_type_count,
- }
- )
-
- return reformat_memory_list
-
- def search(
- self,
- query: str,
- user_id: str,
- install_cube_ids: list[str] | None = None,
- top_k: int = 10,
- mode: Literal["fast", "fine"] = "fast",
- session_id: str | None = None,
- ):
- """Search memories for a specific user."""
-
- # Load user cubes if not already loaded
- time_start = time.time()
- self._load_user_cubes(user_id, self.default_cube_config)
- load_user_cubes_time_end = time.time()
- logger.info(
- f"time search: load_user_cubes time user_id: {user_id} time is: {load_user_cubes_time_end - time_start}"
- )
- search_result = super().search(
- query, user_id, install_cube_ids, top_k, mode=mode, session_id=session_id
- )
- search_time_end = time.time()
- logger.info(
- f"time search: search text_mem time user_id: {user_id} time is: {search_time_end - load_user_cubes_time_end}"
- )
- text_memory_list = search_result["text_mem"]
- reformat_memory_list = []
- for memory in text_memory_list:
- memories_list = []
- for data in memory["memories"]:
- memories = data.model_dump()
- memories["ref_id"] = f"[{memories['id'].split('-')[0]}]"
- memories["metadata"]["embedding"] = []
- memories["metadata"]["sources"] = []
- memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]"
- memories["metadata"]["id"] = memories["id"]
- memories["metadata"]["memory"] = memories["memory"]
- memories_list.append(memories)
- reformat_memory_list.append({"cube_id": memory["cube_id"], "memories": memories_list})
- logger.info(f"search memory list is : {reformat_memory_list}")
- search_result["text_mem"] = reformat_memory_list
-
- pref_memory_list = search_result["pref_mem"]
- reformat_pref_memory_list = []
- for memory in pref_memory_list:
- memories_list = []
- for data in memory["memories"]:
- memories = data.model_dump()
- memories["ref_id"] = f"[{memories['id'].split('-')[0]}]"
- memories["metadata"]["embedding"] = []
- memories["metadata"]["sources"] = []
- memories["metadata"]["ref_id"] = f"[{memories['id'].split('-')[0]}]"
- memories["metadata"]["id"] = memories["id"]
- memories["metadata"]["memory"] = memories["memory"]
- memories_list.append(memories)
- reformat_pref_memory_list.append(
- {"cube_id": memory["cube_id"], "memories": memories_list}
- )
- search_result["pref_mem"] = reformat_pref_memory_list
- time_end = time.time()
- logger.info(
- f"time search: total time for user_id: {user_id} time is: {time_end - time_start}"
- )
- return search_result
-
- def add(
- self,
- user_id: str,
- messages: MessageList | None = None,
- memory_content: str | None = None,
- doc_path: str | None = None,
- mem_cube_id: str | None = None,
- source: str | None = None,
- user_profile: bool = False,
- session_id: str | None = None,
- task_id: str | None = None, # Add task_id parameter
- ):
- """Add memory for a specific user."""
-
- # Load user cubes if not already loaded
- self._load_user_cubes(user_id, self.default_cube_config)
- result = super().add(
- messages,
- memory_content,
- doc_path,
- mem_cube_id,
- user_id,
- session_id=session_id,
- task_id=task_id,
- )
- if user_profile:
- try:
- user_interests = memory_content.split("'userInterests': '")[1].split("', '")[0]
- user_interests = user_interests.replace(",", " ")
- user_profile_memories = self.mem_cubes[
- mem_cube_id
- ].text_mem.internet_retriever.retrieve_from_internet(query=user_interests, top_k=5)
- for memory in user_profile_memories:
- self.mem_cubes[mem_cube_id].text_mem.add(memory)
- except Exception as e:
- logger.error(
- f"Failed to retrieve user profile: {e}, memory_content: {memory_content}"
- )
-
- return result
-
- def list_users(self) -> list:
- """List all registered users."""
- return self.user_manager.list_users()
-
- def get_user_info(self, user_id: str) -> dict:
- """Get user information including accessible cubes."""
- # Use MOSCore's built-in user validation
- # Validate user access
- self._validate_user_access(user_id)
-
- result = super().get_user_info()
-
- return result
-
- def share_cube_with_user(self, cube_id: str, owner_user_id: str, target_user_id: str) -> bool:
- """Share a cube with another user."""
- # Use MOSCore's built-in cube access validation
- self._validate_cube_access(owner_user_id, cube_id)
-
- result = super().share_cube_with_user(cube_id, target_user_id)
-
- return result
-
- def clear_user_chat_history(self, user_id: str) -> None:
- """Clear chat history for a specific user."""
- # Validate user access
- self._validate_user_access(user_id)
-
- super().clear_messages(user_id)
-
- def update_user_config(self, user_id: str, config: MOSConfig) -> bool:
- """Update user configuration.
-
- Args:
- user_id (str): The user ID.
- config (MOSConfig): The new configuration.
-
- Returns:
- bool: True if successful, False otherwise.
- """
- try:
- # Save to persistent storage
- success = self.user_manager.save_user_config(user_id, config)
- if success:
- # Update in-memory config
- self.user_configs[user_id] = config
- logger.info(f"Updated configuration for user {user_id}")
-
- return success
- except Exception as e:
- logger.error(f"Failed to update user config for {user_id}: {e}")
- return False
-
- def get_user_config(self, user_id: str) -> MOSConfig | None:
- """Get user configuration.
-
- Args:
- user_id (str): The user ID.
-
- Returns:
- MOSConfig | None: The user's configuration or None if not found.
- """
- return self.user_manager.get_user_config(user_id)
-
- def get_active_user_count(self) -> int:
- """Get the number of active user configurations in memory."""
- return len(self.user_configs)
-
- def get_user_instance_info(self) -> dict[str, Any]:
- """Get information about user configurations in memory."""
- return {
- "active_instances": len(self.user_configs),
- "max_instances": self.max_user_instances,
- "user_ids": list(self.user_configs.keys()),
- "lru_order": list(self.user_configs.keys()), # OrderedDict maintains insertion order
- }
diff --git a/src/memos/mem_os/product_server.py b/src/memos/mem_os/product_server.py
deleted file mode 100644
index 80aefea85..000000000
--- a/src/memos/mem_os/product_server.py
+++ /dev/null
@@ -1,457 +0,0 @@
-import asyncio
-import time
-
-from datetime import datetime
-from typing import Literal
-
-from memos.context.context import ContextThread
-from memos.llms.base import BaseLLM
-from memos.log import get_logger
-from memos.mem_cube.navie import NaiveMemCube
-from memos.mem_os.product import _format_mem_block
-from memos.mem_reader.base import BaseMemReader
-from memos.memories.textual.item import TextualMemoryItem
-from memos.templates.mos_prompts import (
- get_memos_prompt,
-)
-from memos.types import MessageList
-
-
-logger = get_logger(__name__)
-
-
-class MOSServer:
- def __init__(
- self,
- mem_reader: BaseMemReader | None = None,
- llm: BaseLLM | None = None,
- online_bot: bool = False,
- ):
- self.mem_reader = mem_reader
- self.chat_llm = llm
- self.online_bot = online_bot
-
- def chat(
- self,
- query: str,
- user_id: str,
- cube_id: str | None = None,
- mem_cube: NaiveMemCube | None = None,
- history: MessageList | None = None,
- base_prompt: str | None = None,
- internet_search: bool = False,
- moscube: bool = False,
- top_k: int = 10,
- threshold: float = 0.5,
- session_id: str | None = None,
- ) -> str:
- """
- Chat with LLM with memory references and complete response.
- """
- time_start = time.time()
- memories_result = mem_cube.text_mem.search(
- query=query,
- user_name=cube_id,
- top_k=top_k,
- mode="fine",
- manual_close_internet=not internet_search,
- moscube=moscube,
- info={
- "user_id": user_id,
- "session_id": session_id,
- "chat_history": history,
- },
- )
-
- memories_list = []
- if memories_result:
- memories_list = self._filter_memories_by_threshold(memories_result, threshold)
- new_memories_list = []
- for m in memories_list:
- m.metadata.embedding = []
- new_memories_list.append(m)
- memories_list = new_memories_list
- system_prompt = self._build_system_prompt(memories_list, base_prompt)
-
- history_info = []
- if history:
- history_info = history[-20:]
- current_messages = [
- {"role": "system", "content": system_prompt},
- *history_info,
- {"role": "user", "content": query},
- ]
- response = self.chat_llm.generate(current_messages)
- time_end = time.time()
- self._start_post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- session_id=session_id,
- query=query,
- full_response=response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=0.0,
- current_messages=current_messages,
- mem_cube=mem_cube,
- history=history,
- )
- return response, memories_list
-
- def add(
- self,
- user_id: str,
- cube_id: str,
- mem_cube: NaiveMemCube,
- messages: MessageList,
- session_id: str | None = None,
- history: MessageList | None = None,
- ) -> list[str]:
- memories = self.mem_reader.get_memory(
- [messages],
- type="chat",
- info={
- "user_id": user_id,
- "session_id": session_id,
- "chat_history": history,
- },
- )
- flattened_memories = [mm for m in memories for mm in m]
- mem_id_list: list[str] = mem_cube.text_mem.add(
- flattened_memories,
- user_name=cube_id,
- )
- return mem_id_list
-
- def search(
- self,
- user_id: str,
- cube_id: str,
- session_id: str | None = None,
- ) -> None:
- NotImplementedError("Not implemented")
-
- def _filter_memories_by_threshold(
- self,
- memories: list[TextualMemoryItem],
- threshold: float = 0.30,
- min_num: int = 3,
- memory_type: Literal["OuterMemory"] = "OuterMemory",
- ) -> list[TextualMemoryItem]:
- """
- Filter memories by threshold and type, at least min_num memories for Non-OuterMemory.
- Args:
- memories: list[TextualMemoryItem],
- threshold: float,
- min_num: int,
- memory_type: Literal["OuterMemory"],
- Returns:
- list[TextualMemoryItem]
- """
- sorted_memories = sorted(memories, key=lambda m: m.metadata.relativity, reverse=True)
- filtered_person = [m for m in memories if m.metadata.memory_type != memory_type]
- filtered_outer = [m for m in memories if m.metadata.memory_type == memory_type]
- filtered = []
- per_memory_count = 0
- for m in sorted_memories:
- if m.metadata.relativity >= threshold:
- if m.metadata.memory_type != memory_type:
- per_memory_count += 1
- filtered.append(m)
- if len(filtered) < min_num:
- filtered = filtered_person[:min_num] + filtered_outer[:min_num]
- else:
- if per_memory_count < min_num:
- filtered += filtered_person[per_memory_count:min_num]
- filtered_memory = sorted(filtered, key=lambda m: m.metadata.relativity, reverse=True)
- return filtered_memory
-
- def _build_base_system_prompt(
- self,
- base_prompt: str | None = None,
- tone: str = "friendly",
- verbosity: str = "mid",
- mode: str = "enhance",
- ) -> str:
- """
- Build base system prompt without memory references.
- """
- now = datetime.now()
- formatted_date = now.strftime("%Y-%m-%d (%A)")
- sys_body = get_memos_prompt(date=formatted_date, tone=tone, verbosity=verbosity, mode=mode)
- prefix = (base_prompt.strip() + "\n\n") if base_prompt else ""
- return prefix + sys_body
-
- def _build_system_prompt(
- self,
- memories: list[TextualMemoryItem] | list[str] | None = None,
- base_prompt: str | None = None,
- **kwargs,
- ) -> str:
- """Build system prompt with optional memories context."""
- if base_prompt is None:
- base_prompt = (
- "You are a knowledgeable and helpful AI assistant. "
- "You have access to conversation memories that help you provide more personalized responses. "
- "Use the memories to understand the user's context, preferences, and past interactions. "
- "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
- )
-
- memory_context = ""
- if memories:
- memory_list = []
- for i, memory in enumerate(memories, 1):
- if isinstance(memory, TextualMemoryItem):
- text_memory = memory.memory
- else:
- if not isinstance(memory, str):
- logger.error("Unexpected memory type.")
- text_memory = memory
- memory_list.append(f"{i}. {text_memory}")
- memory_context = "\n".join(memory_list)
-
- if "{memories}" in base_prompt:
- return base_prompt.format(memories=memory_context)
- elif base_prompt and memories:
- # For backward compatibility, append memories if no placeholder is found
- memory_context_with_header = "\n\n## Memories:\n" + memory_context
- return base_prompt + memory_context_with_header
- return base_prompt
-
- def _build_memory_context(
- self,
- memories_all: list[TextualMemoryItem],
- mode: str = "enhance",
- ) -> str:
- """
- Build memory context to be included in user message.
- """
- if not memories_all:
- return ""
-
- mem_block_o, mem_block_p = _format_mem_block(memories_all)
-
- if mode == "enhance":
- return (
- "# Memories\n## PersonalMemory (ordered)\n"
- + mem_block_p
- + "\n## OuterMemory (ordered)\n"
- + mem_block_o
- + "\n\n"
- )
- else:
- mem_block = mem_block_o + "\n" + mem_block_p
- return "# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + mem_block + "\n\n"
-
- def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]:
- """
- Extract reference information from the response and return clean text.
-
- Args:
- response (str): The complete response text.
-
- Returns:
- tuple[str, list[dict]]: A tuple containing:
- - clean_text: Text with reference markers removed
- - references: List of reference information
- """
- import re
-
- try:
- references = []
- # Pattern to match [refid:memoriesID]
- pattern = r"\[(\d+):([^\]]+)\]"
-
- matches = re.findall(pattern, response)
- for ref_number, memory_id in matches:
- references.append({"memory_id": memory_id, "reference_number": int(ref_number)})
-
- # Remove all reference markers from the text to get clean text
- clean_text = re.sub(pattern, "", response)
-
- # Clean up any extra whitespace that might be left after removing markers
- clean_text = re.sub(r"\s+", " ", clean_text).strip()
-
- return clean_text, references
- except Exception as e:
- logger.error(f"Error extracting references from response: {e}", exc_info=True)
- return response, []
-
- async def _post_chat_processing(
- self,
- user_id: str,
- cube_id: str,
- query: str,
- full_response: str,
- system_prompt: str,
- time_start: float,
- time_end: float,
- speed_improvement: float,
- current_messages: list,
- mem_cube: NaiveMemCube | None = None,
- session_id: str | None = None,
- history: MessageList | None = None,
- ) -> None:
- """
- Asynchronous processing of logs, notifications and memory additions
- """
- try:
- logger.info(
- f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}"
- )
- logger.info(f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}")
-
- clean_response, extracted_references = self._extract_references_from_response(
- full_response
- )
- logger.info(f"Extracted {len(extracted_references)} references from response")
-
- # Send chat report notifications asynchronously
- if self.online_bot:
- try:
- from memos.memos_tools.notification_utils import (
- send_online_bot_notification_async,
- )
-
- # Prepare notification data
- chat_data = {
- "query": query,
- "user_id": user_id,
- "cube_id": cube_id,
- "system_prompt": system_prompt,
- "full_response": full_response,
- }
-
- system_data = {
- "references": extracted_references,
- "time_start": time_start,
- "time_end": time_end,
- "speed_improvement": speed_improvement,
- }
-
- emoji_config = {"chat": "๐ฌ", "system_info": "๐"}
-
- await send_online_bot_notification_async(
- online_bot=self.online_bot,
- header_name="MemOS Chat Report",
- sub_title_name="chat_with_references",
- title_color="#00956D",
- other_data1=chat_data,
- other_data2=system_data,
- emoji=emoji_config,
- )
- except Exception as e:
- logger.warning(f"Failed to send chat notification (async): {e}")
-
- self.add(
- user_id=user_id,
- cube_id=cube_id,
- mem_cube=mem_cube,
- session_id=session_id,
- history=history,
- messages=[
- {
- "role": "user",
- "content": query,
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- {
- "role": "assistant",
- "content": clean_response, # Store clean text without reference markers
- "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
- },
- ],
- )
-
- logger.info(f"Post-chat processing completed for user {user_id}")
-
- except Exception as e:
- logger.error(f"Error in post-chat processing for user {user_id}: {e}", exc_info=True)
-
- def _start_post_chat_processing(
- self,
- user_id: str,
- cube_id: str,
- query: str,
- full_response: str,
- system_prompt: str,
- time_start: float,
- time_end: float,
- speed_improvement: float,
- current_messages: list,
- mem_cube: NaiveMemCube | None = None,
- session_id: str | None = None,
- history: MessageList | None = None,
- ) -> None:
- """
- Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments
- """
-
- def run_async_in_thread():
- """Running asynchronous tasks in a new thread"""
- try:
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- try:
- loop.run_until_complete(
- self._post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- mem_cube=mem_cube,
- session_id=session_id,
- history=history,
- )
- )
- finally:
- loop.close()
- except Exception as e:
- logger.error(
- f"Error in thread-based post-chat processing for user {user_id}: {e}",
- exc_info=True,
- )
-
- try:
- # Try to get the current event loop
- asyncio.get_running_loop()
- # Create task and store reference to prevent garbage collection
- task = asyncio.create_task(
- self._post_chat_processing(
- user_id=user_id,
- cube_id=cube_id,
- query=query,
- full_response=full_response,
- system_prompt=system_prompt,
- time_start=time_start,
- time_end=time_end,
- speed_improvement=speed_improvement,
- current_messages=current_messages,
- )
- )
- # Add exception handling for the background task
- task.add_done_callback(
- lambda t: (
- logger.error(
- f"Error in background post-chat processing for user {user_id}: {t.exception()}",
- exc_info=True,
- )
- if t.exception()
- else None
- )
- )
- except RuntimeError:
- # No event loop, run in a new thread with context propagation
- thread = ContextThread(
- target=run_async_in_thread,
- name=f"PostChatProcessing-{user_id}",
- # Set as a daemon thread to avoid blocking program exit
- daemon=True,
- )
- thread.start()
diff --git a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
index ec431c253..671190e6f 100644
--- a/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
+++ b/src/memos/mem_scheduler/general_modules/init_components_for_scheduler.py
@@ -47,13 +47,14 @@ def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
graph_db_backend_map = {
"neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
"neo4j": APIConfig.get_neo4j_config(user_id=user_id),
- "nebular": APIConfig.get_nebular_config(user_id=user_id),
"polardb": APIConfig.get_polardb_config(user_id=user_id),
"postgres": APIConfig.get_postgres_config(user_id=user_id),
}
# Support both GRAPH_DB_BACKEND and legacy NEO4J_BACKEND env vars
- graph_db_backend = os.getenv("GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "nebular")).lower()
+ graph_db_backend = os.getenv(
+ "GRAPH_DB_BACKEND", os.getenv("NEO4J_BACKEND", "neo4j-community")
+ ).lower()
return GraphDBConfigFactory.model_validate(
{
"backend": graph_db_backend,
diff --git a/tests/api/test_product_router.py b/tests/api/test_product_router.py
deleted file mode 100644
index 857b290c5..000000000
--- a/tests/api/test_product_router.py
+++ /dev/null
@@ -1,422 +0,0 @@
-"""
-Unit tests for product_router input/output format validation.
-
-This module tests that the product_router endpoints correctly validate
-input request formats and return properly formatted responses.
-"""
-
-from unittest.mock import Mock, patch
-
-import pytest
-
-from fastapi.testclient import TestClient
-
-# Patch the MOS_PRODUCT_INSTANCE directly after import
-# Patch MOS_PRODUCT_INSTANCE and MOSProduct so we can test the FastAPI router
-# without initializing the full MemOS product stack.
-import memos.api.routers.product_router as pr_module
-
-
-_mock_mos_instance = Mock()
-pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance
-pr_module.get_mos_product_instance = lambda: _mock_mos_instance
-with patch("memos.mem_os.product.MOSProduct", return_value=_mock_mos_instance):
- from memos.api import product_api
-
-
-@pytest.fixture(scope="module")
-def mock_mos_product_instance():
- """Mock get_mos_product_instance for all tests."""
- # Ensure the mock is set
- pr_module.MOS_PRODUCT_INSTANCE = _mock_mos_instance
- pr_module.get_mos_product_instance = lambda: _mock_mos_instance
- yield product_api.app, _mock_mos_instance
-
-
-@pytest.fixture
-def client(mock_mos_product_instance):
- """Create test client for product_api."""
- app, _ = mock_mos_product_instance
- return TestClient(app)
-
-
-@pytest.fixture
-def mock_mos_product(mock_mos_product_instance):
- """Get the mocked MOSProduct instance."""
- _, mock_instance = mock_mos_product_instance
- # Ensure get_mos_product_instance returns this mock
- import memos.api.routers.product_router as pr_module
-
- pr_module.get_mos_product_instance = lambda: mock_instance
- pr_module.MOS_PRODUCT_INSTANCE = mock_instance
- return mock_instance
-
-
-@pytest.fixture(autouse=True)
-def setup_mock_mos_product(mock_mos_product):
- """Set up default return values for MOSProduct methods."""
- # Set up default return values for methods
- mock_mos_product.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []}
- mock_mos_product.add.return_value = None
- mock_mos_product.chat.return_value = ("test response", [])
- mock_mos_product.chat_with_references.return_value = iter(
- ['data: {"type": "content", "data": "test"}\n\n']
- )
- # Ensure get_all and get_subgraph return proper list format (MemoryResponse expects list)
- default_memory_result = [{"cube_id": "test_cube", "memories": []}]
- mock_mos_product.get_all.return_value = default_memory_result
- mock_mos_product.get_subgraph.return_value = default_memory_result
- mock_mos_product.get_suggestion_query.return_value = ["suggestion1", "suggestion2"]
- # Ensure get_mos_product_instance returns the mock
- import memos.api.routers.product_router as pr_module
-
- pr_module.get_mos_product_instance = lambda: mock_mos_product
-
-
-class TestProductRouterSearch:
- """Test /search endpoint input/output format."""
-
- def test_search_valid_input_output(self, mock_mos_product, client):
- """Test search endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- "mem_cube_id": "test_cube",
- "top_k": 10,
- }
-
- response = client.post("/product/search", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "code" in data
- assert "message" in data
- assert "data" in data
- assert data["code"] == 200
- assert isinstance(data["data"], dict)
-
- # Verify MOSProduct.search was called with correct parameters
- mock_mos_product.search.assert_called_once()
- call_kwargs = mock_mos_product.search.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["query"] == "test query"
-
- def test_search_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test search endpoint with missing required field."""
- request_data = {
- "query": "test query",
- }
-
- response = client.post("/product/search", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_search_response_format(self, mock_mos_product, client):
- """Test search endpoint returns SearchResponse format."""
- mock_mos_product.search.return_value = {
- "text_mem": [{"cube_id": "test_cube", "memories": []}],
- "act_mem": [],
- "para_mem": [],
- }
-
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- }
-
- response = client.post("/product/search", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Search completed successfully"
- assert isinstance(data["data"], dict)
- assert "text_mem" in data["data"]
-
-
-class TestProductRouterAdd:
- """Test /add endpoint input/output format."""
-
- def test_add_valid_input_output(self, mock_mos_product, client):
- """Test add endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "memory_content": "test memory content",
- "mem_cube_id": "test_cube",
- }
-
- response = client.post("/product/add", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "code" in data
- assert "message" in data
- assert "data" in data
- assert data["code"] == 200
- assert data["data"] is None # SimpleResponse has None data
-
- # Verify MOSProduct.add was called with correct parameters
- mock_mos_product.add.assert_called_once()
- call_kwargs = mock_mos_product.add.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["memory_content"] == "test memory content"
-
- def test_add_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test add endpoint with missing required field."""
- request_data = {
- "memory_content": "test memory content",
- }
-
- response = client.post("/product/add", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_add_response_format(self, mock_mos_product, client):
- """Test add endpoint returns SimpleResponse format."""
- request_data = {
- "user_id": "test_user",
- "memory_content": "test memory content",
- }
-
- response = client.post("/product/add", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Memory created successfully"
- assert data["data"] is None
-
-
-class TestProductRouterChatComplete:
- """Test /chat/complete endpoint input/output format."""
-
- def test_chat_complete_valid_input_output(self, mock_mos_product, client):
- """Test chat/complete endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- "mem_cube_id": "test_cube",
- }
-
- response = client.post("/product/chat/complete", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "message" in data
- assert "data" in data
- assert isinstance(data["data"], dict)
- assert "response" in data["data"]
- assert "references" in data["data"]
-
- # Verify MOSProduct.chat was called with correct parameters
- mock_mos_product.chat.assert_called_once()
- call_kwargs = mock_mos_product.chat.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["query"] == "test query"
-
- def test_chat_complete_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test chat/complete endpoint with missing required field."""
- request_data = {
- "query": "test query",
- }
-
- response = client.post("/product/chat/complete", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_chat_complete_response_format(self, mock_mos_product, client):
- """Test chat/complete endpoint returns correct format."""
- mock_mos_product.chat.return_value = ("test response", [{"id": "ref1"}])
-
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- }
-
- response = client.post("/product/chat/complete", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Chat completed successfully"
- assert isinstance(data["data"]["response"], str)
- assert isinstance(data["data"]["references"], list)
-
-
-class TestProductRouterChat:
- """Test /chat endpoint input/output format (SSE stream)."""
-
- def test_chat_valid_input_output(self, mock_mos_product, client):
- """Test chat endpoint with valid input returns SSE stream."""
- request_data = {
- "user_id": "test_user",
- "query": "test query",
- "mem_cube_id": "test_cube",
- }
-
- response = client.post("/product/chat", json=request_data)
-
- assert response.status_code == 200
- assert "text/event-stream" in response.headers["content-type"]
-
- # Verify MOSProduct.chat_with_references was called
- mock_mos_product.chat_with_references.assert_called_once()
- call_kwargs = mock_mos_product.chat_with_references.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["query"] == "test query"
-
- def test_chat_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test chat endpoint with missing required field."""
- request_data = {
- "query": "test query",
- }
-
- response = client.post("/product/chat", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
-
-class TestProductRouterSuggestions:
- """Test /suggestions endpoint input/output format."""
-
- def test_suggestions_valid_input_output(self, mock_mos_product, client):
- """Test suggestions endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "mem_cube_id": "test_cube",
- "language": "zh",
- }
-
- response = client.post("/product/suggestions", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "code" in data
- assert "message" in data
- assert "data" in data
- assert data["code"] == 200
- assert isinstance(data["data"], dict)
- assert "query" in data["data"]
-
- # Verify MOSProduct.get_suggestion_query was called
- mock_mos_product.get_suggestion_query.assert_called_once()
- call_kwargs = mock_mos_product.get_suggestion_query.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
-
- def test_suggestions_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test suggestions endpoint with missing required field."""
- request_data = {
- "mem_cube_id": "test_cube",
- }
-
- response = client.post("/product/suggestions", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_suggestions_response_format(self, mock_mos_product, client):
- """Test suggestions endpoint returns SuggestionResponse format."""
- mock_mos_product.get_suggestion_query.return_value = [
- "suggestion1",
- "suggestion2",
- "suggestion3",
- ]
-
- request_data = {
- "user_id": "test_user",
- "mem_cube_id": "test_cube",
- "language": "en",
- }
-
- response = client.post("/product/suggestions", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Suggestions retrieved successfully"
- assert isinstance(data["data"], dict)
- assert isinstance(data["data"]["query"], list)
-
-
-class TestProductRouterGetAll:
- """Test /get_all endpoint input/output format."""
-
- def test_get_all_valid_input_output(self, mock_mos_product, client):
- """Test get_all endpoint with valid input returns correct output format."""
- request_data = {
- "user_id": "test_user",
- "memory_type": "text_mem",
- }
-
- response = client.post("/product/get_all", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
-
- # Validate response structure
- assert "code" in data
- assert "message" in data
- assert "data" in data
- assert data["code"] == 200
- assert isinstance(data["data"], list)
-
- # Verify MOSProduct.get_all was called
- mock_mos_product.get_all.assert_called_once()
- call_kwargs = mock_mos_product.get_all.call_args[1]
- assert call_kwargs["user_id"] == "test_user"
- assert call_kwargs["memory_type"] == "text_mem"
-
- def test_get_all_with_search_query(self, mock_mos_product, client):
- """Test get_all endpoint with search_query uses get_subgraph."""
- # Reset mock call counts
- mock_mos_product.get_all.reset_mock()
- mock_mos_product.get_subgraph.reset_mock()
-
- request_data = {
- "user_id": "test_user",
- "memory_type": "text_mem",
- "search_query": "test query",
- }
-
- response = client.post("/product/get_all", json=request_data)
-
- assert response.status_code == 200
- # Verify get_subgraph was called instead of get_all
- mock_mos_product.get_subgraph.assert_called_once()
- mock_mos_product.get_all.assert_not_called()
-
- def test_get_all_invalid_input_missing_user_id(self, mock_mos_product, client):
- """Test get_all endpoint with missing required field."""
- request_data = {
- "memory_type": "text_mem",
- }
-
- response = client.post("/product/get_all", json=request_data)
-
- # Should return validation error
- assert response.status_code == 422
-
- def test_get_all_response_format(self, mock_mos_product, client):
- """Test get_all endpoint returns MemoryResponse format."""
- mock_mos_product.get_all.return_value = [{"cube_id": "test_cube", "memories": []}]
-
- request_data = {
- "user_id": "test_user",
- "memory_type": "text_mem",
- }
-
- response = client.post("/product/get_all", json=request_data)
-
- assert response.status_code == 200
- data = response.json()
- assert data["message"] == "Memories retrieved successfully"
- assert isinstance(data["data"], list)
- assert len(data["data"]) > 0
diff --git a/tests/api/test_start_api.py b/tests/api/test_start_api.py
deleted file mode 100644
index e1ffcd74b..000000000
--- a/tests/api/test_start_api.py
+++ /dev/null
@@ -1,401 +0,0 @@
-from unittest.mock import Mock, patch
-
-import pytest
-
-from fastapi.testclient import TestClient
-
-from memos.api.start_api import app
-from memos.mem_user.user_manager import UserRole
-
-
-client = TestClient(app)
-
-# Mock data
-MOCK_MESSAGE = {"role": "user", "content": "test message"}
-MOCK_MEMORY_CREATE = {
- "messages": [MOCK_MESSAGE],
- "mem_cube_id": "test_cube",
- "user_id": "test_user",
-}
-MOCK_MEMORY_CONTENT = {
- "memory_content": "test memory content",
- "mem_cube_id": "test_cube",
- "user_id": "test_user",
-}
-MOCK_DOC_PATH = {"doc_path": "/path/to/doc", "mem_cube_id": "test_cube", "user_id": "test_user"}
-MOCK_SEARCH_REQUEST = {
- "query": "test query",
- "user_id": "test_user",
- "install_cube_ids": ["test_cube"],
-}
-MOCK_MEMCUBE_REGISTER = {
- "mem_cube_name_or_path": "test_cube_path",
- "mem_cube_id": "test_cube",
- "user_id": "test_user",
-}
-MOCK_CHAT_REQUEST = {"query": "test chat query", "user_id": "test_user"}
-MOCK_USER_CREATE = {"user_id": "test_user", "user_name": "Test User", "role": "USER"}
-MOCK_CUBE_SHARE = {"target_user_id": "target_user"}
-MOCK_CONFIG = {
- "user_id": "test_user",
- "session_id": "test_session",
- "enable_textual_memory": True,
- "enable_activation_memory": False,
- "top_k": 5,
- "chat_model": {
- "backend": "openai",
- "config": {
- "model_name_or_path": "gpt-3.5-turbo",
- "api_key": "test_key",
- "temperature": 0.7,
- "api_base": "https://api.openai.com/v1",
- },
- },
-}
-
-
-@pytest.fixture
-def mock_mos():
- """Mock MOS instance for testing."""
- with patch("memos.api.start_api.get_mos_instance") as mock_get_mos:
- # Create a mock MOS instance
- mock_instance = Mock()
-
- # Set up default return values for methods
- mock_instance.search.return_value = {"text_mem": [], "act_mem": [], "para_mem": []}
- mock_instance.get_all.return_value = {"text_mem": [], "act_mem": [], "para_mem": []}
- mock_instance.get.return_value = {"memory": "test memory"}
- mock_instance.chat.return_value = "test response"
- mock_instance.list_users.return_value = []
- mock_instance.get_user_info.return_value = {
- "user_id": "test_user",
- "user_name": "Test User",
- "role": "user",
- "accessible_cubes": [],
- }
- mock_instance.create_user.return_value = "test_user"
- mock_instance.share_cube_with_user.return_value = True
-
- # Configure the mock to return our mock instance
- mock_get_mos.return_value = mock_instance
-
- yield mock_instance
-
-
-def test_configure_error(mock_mos):
- """Test configuration endpoint with error."""
- with patch("memos.api.start_api.MOS_INSTANCE", None):
- response = client.post("/configure", json={})
- assert response.status_code == 422 # FastAPI validation error
-
-
-def test_create_user(mock_mos):
- """Test user creation endpoint."""
- response = client.post("/users", json=MOCK_USER_CREATE)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "User created successfully",
- "data": {"user_id": "test_user"},
- }
- mock_mos.create_user.assert_called_once_with(
- user_id="test_user", role=UserRole.USER, user_name="Test User"
- )
-
-
-def test_create_user_validation_error(mock_mos):
- """Test user creation with validation error."""
- mock_mos.create_user.side_effect = ValueError("Invalid user data")
- response = client.post("/users", json=MOCK_USER_CREATE)
- assert response.status_code == 400
- assert "Invalid user data" in response.json()["message"]
-
-
-def test_list_users(mock_mos):
- """Test list users endpoint."""
- # Set up mock to return the expected data structure
- mock_users = [
- {
- "user_id": "test_user",
- "user_name": "Test User",
- "role": "user",
- "created_at": "2023-01-01T00:00:00",
- "is_active": True,
- }
- ]
- mock_mos.list_users.return_value = mock_users
-
- response = client.get("/users")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Users retrieved successfully",
- "data": mock_users,
- }
- mock_mos.list_users.assert_called_once()
-
-
-def test_get_user_info(mock_mos):
- """Test get user info endpoint."""
- # Set up mock to return the expected data structure
- mock_user_info = {
- "user_id": "test_user",
- "user_name": "Test User",
- "role": "user",
- "created_at": "2023-01-01T00:00:00",
- "accessible_cubes": [],
- }
- mock_mos.get_user_info.return_value = mock_user_info
-
- response = client.get("/users/me")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "User info retrieved successfully",
- "data": mock_user_info,
- }
- mock_mos.get_user_info.assert_called_once()
-
-
-def test_register_mem_cube(mock_mos):
- """Test MemCube registration endpoint."""
- response = client.post("/mem_cubes", json=MOCK_MEMCUBE_REGISTER)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "MemCube registered successfully",
- "data": None,
- }
- mock_mos.register_mem_cube.assert_called_once_with(
- mem_cube_name_or_path="test_cube_path", mem_cube_id="test_cube", user_id="test_user"
- )
-
-
-def test_register_mem_cube_validation_error(mock_mos):
- """Test MemCube registration with validation error."""
- mock_mos.register_mem_cube.side_effect = ValueError("Invalid MemCube")
- response = client.post("/mem_cubes", json=MOCK_MEMCUBE_REGISTER)
- assert response.status_code == 400
- assert "Invalid MemCube" in response.json()["message"]
-
-
-def test_unregister_mem_cube(mock_mos):
- """Test MemCube unregistration endpoint."""
- response = client.delete("/mem_cubes/test_cube?user_id=test_user")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "MemCube unregistered successfully",
- "data": None,
- }
- mock_mos.unregister_mem_cube.assert_called_once_with(
- mem_cube_id="test_cube", user_id="test_user"
- )
-
-
-def test_unregister_nonexistent_mem_cube(mock_mos):
- """Test unregistering a non-existent MemCube."""
- mock_mos.unregister_mem_cube.side_effect = ValueError("MemCube not found")
- response = client.delete("/mem_cubes/nonexistent_cube")
- assert response.status_code == 400
- assert "MemCube not found" in response.json()["message"]
-
-
-def test_share_cube(mock_mos):
- """Test cube sharing endpoint."""
- response = client.post("/mem_cubes/test_cube/share", json=MOCK_CUBE_SHARE)
- assert response.status_code == 200
- assert response.json() == {"code": 200, "message": "Cube shared successfully", "data": None}
- mock_mos.share_cube_with_user.assert_called_once_with("test_cube", "target_user")
-
-
-def test_share_cube_failure(mock_mos):
- """Test cube sharing failure."""
- mock_mos.share_cube_with_user.return_value = False
- response = client.post("/mem_cubes/test_cube/share", json=MOCK_CUBE_SHARE)
- assert response.status_code == 400
- assert "Failed to share cube" in response.json()["message"]
-
-
-@pytest.mark.parametrize(
- "memory_create,expected_calls",
- [
- (MOCK_MEMORY_CREATE, {"messages": [MOCK_MESSAGE]}),
- (MOCK_MEMORY_CONTENT, {"memory_content": "test memory content"}),
- (MOCK_DOC_PATH, {"doc_path": "/path/to/doc"}),
- ],
-)
-def test_add_memory(mock_mos, memory_create, expected_calls):
- """Test adding memories with different types of content."""
- response = client.post("/memories", json=memory_create)
- assert response.status_code == 200
- assert response.json() == {"code": 200, "message": "Memories added successfully", "data": None}
- mock_mos.add.assert_called_once()
-
-
-def test_add_memory_validation_error(mock_mos):
- """Test adding memory with validation error."""
- response = client.post("/memories", json={})
- assert response.status_code == 400
- assert "must be provided" in response.json()["message"]
-
-
-def test_get_all_memories(mock_mos):
- """Test get all memories endpoint."""
- mock_results = {
- "text_mem": [{"cube_id": "test_cube", "memories": []}],
- "act_mem": [],
- "para_mem": [],
- }
- mock_mos.get_all.return_value = mock_results
-
- response = client.get("/memories")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Memories retrieved successfully",
- "data": mock_results,
- }
- mock_mos.get_all.assert_called_once_with(mem_cube_id=None, user_id=None)
-
-
-def test_get_memory(mock_mos):
- """Test get specific memory endpoint."""
- mock_memory = {"memory": "test memory content"}
- mock_mos.get.return_value = mock_memory
-
- response = client.get("/memories/test_cube/test_memory")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Memory retrieved successfully",
- "data": mock_memory,
- }
- mock_mos.get.assert_called_once_with(
- mem_cube_id="test_cube", memory_id="test_memory", user_id=None
- )
-
-
-def test_get_nonexistent_memory(mock_mos):
- """Test getting a non-existent memory."""
- mock_mos.get.side_effect = ValueError("Memory not found")
- response = client.get("/memories/test_cube/nonexistent_memory")
- assert response.status_code == 400
- assert "Memory not found" in response.json()["message"]
-
-
-def test_search_memories(mock_mos):
- """Test search memories endpoint."""
- # Mock the search method to return a proper result structure
- mock_results = {"text_mem": [], "act_mem": [], "para_mem": []}
- mock_mos.search.return_value = mock_results
-
- # Ensure the search request has all required fields
- search_request = {
- "query": "test query",
- "user_id": "test_user",
- "install_cube_ids": ["test_cube"],
- }
-
- response = client.post("/search", json=search_request)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Search completed successfully",
- "data": mock_results,
- }
- mock_mos.search.assert_called_once_with(
- query="test query", user_id="test_user", install_cube_ids=["test_cube"]
- )
-
-
-def test_update_memory(mock_mos):
- """Test updating a memory endpoint."""
- update_data = {"content": "updated content"}
- response = client.put("/memories/test_cube/test_memory?user_id=test_user", json=update_data)
- assert response.status_code == 200
- assert response.json() == {"code": 200, "message": "Memory updated successfully", "data": None}
- mock_mos.update.assert_called_once_with(
- mem_cube_id="test_cube",
- memory_id="test_memory",
- text_memory_item=update_data,
- user_id="test_user",
- )
-
-
-def test_update_nonexistent_memory(mock_mos):
- """Test updating a non-existent memory."""
- mock_mos.update.side_effect = ValueError("Memory not found")
- response = client.put("/memories/test_cube/nonexistent_memory", json={})
- assert response.status_code == 400
- assert "Memory not found" in response.json()["message"]
-
-
-def test_delete_memory(mock_mos):
- """Test deleting a memory endpoint."""
- response = client.delete("/memories/test_cube/test_memory?user_id=test_user")
- assert response.status_code == 200
- assert response.json() == {"code": 200, "message": "Memory deleted successfully", "data": None}
- mock_mos.delete.assert_called_once_with(
- mem_cube_id="test_cube", memory_id="test_memory", user_id="test_user"
- )
-
-
-def test_delete_nonexistent_memory(mock_mos):
- """Test deleting a non-existent memory."""
- mock_mos.delete.side_effect = ValueError("Memory not found")
- response = client.delete("/memories/test_cube/nonexistent_memory")
- assert response.status_code == 400
- assert "Memory not found" in response.json()["message"]
-
-
-def test_delete_all_memories(mock_mos):
- """Test deleting all memories endpoint."""
- response = client.delete("/memories/test_cube?user_id=test_user")
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "All memories deleted successfully",
- "data": None,
- }
- mock_mos.delete_all.assert_called_once_with(mem_cube_id="test_cube", user_id="test_user")
-
-
-def test_delete_all_nonexistent_memories(mock_mos):
- """Test deleting all memories from non-existent MemCube."""
- mock_mos.delete_all.side_effect = ValueError("MemCube not found")
- response = client.delete("/memories/nonexistent_cube")
- assert response.status_code == 400
- assert "MemCube not found" in response.json()["message"]
-
-
-def test_chat(mock_mos):
- """Test chat endpoint."""
- response = client.post("/chat", json=MOCK_CHAT_REQUEST)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Chat response generated",
- "data": "test response",
- }
- mock_mos.chat.assert_called_once_with(query="test chat query", user_id="test_user")
-
-
-def test_chat_without_user_id(mock_mos):
- """Test chat endpoint without user_id."""
- chat_request = {"query": "test chat query"}
- response = client.post("/chat", json=chat_request)
- assert response.status_code == 200
- assert response.json() == {
- "code": 200,
- "message": "Chat response generated",
- "data": "test response",
- }
- mock_mos.chat.assert_called_once_with(query="test chat query", user_id=None)
-
-
-def test_home_redirect():
- """Test home endpoint redirects to docs."""
- response = client.get("/", follow_redirects=False)
- assert response.status_code == 307
- assert response.headers["location"] == "/docs"
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 9750af121..a1e423e4f 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -16,13 +16,13 @@
class TestExportOpenAPI:
"""Test the export_openapi function."""
- @patch("memos.api.start_api.app")
+ @patch("memos.cli.get_openapi_app")
@patch("builtins.open", new_callable=mock_open)
@patch("os.makedirs")
def test_export_openapi_success(self, mock_makedirs, mock_file, mock_app):
"""Test successful OpenAPI export."""
mock_openapi_data = {"openapi": "3.0.0", "info": {"title": "Test API"}}
- mock_app.openapi.return_value = mock_openapi_data
+ mock_app.return_value.openapi.return_value = mock_openapi_data
result = export_openapi("/test/path/openapi.json")
@@ -30,11 +30,11 @@ def test_export_openapi_success(self, mock_makedirs, mock_file, mock_app):
mock_makedirs.assert_called_once_with("/test/path", exist_ok=True)
mock_file.assert_called_once_with("/test/path/openapi.json", "w")
- @patch("memos.api.start_api.app")
+ @patch("memos.cli.get_openapi_app")
@patch("builtins.open", side_effect=OSError("Permission denied"))
def test_export_openapi_error(self, mock_file, mock_app):
"""Test OpenAPI export when file writing fails."""
- mock_app.openapi.return_value = {"test": "data"}
+ mock_app.return_value.openapi.return_value = {"test": "data"}
with pytest.raises(IOError):
export_openapi("/invalid/path/openapi.json")