diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py index ffc9dda12..91d695acc 100644 --- a/evaluation/scripts/utils/client.py +++ b/evaluation/scripts/utils/client.py @@ -183,6 +183,7 @@ def search(self, query, user_id, top_k): "mem_cube_id": user_id, "conversation_id": "", "top_k": top_k, + "mode": "mixture", }, ensure_ascii=False, ) @@ -230,6 +231,7 @@ def search(self, query, user_id, top_k): "query": query, "user_id": user_id, "memory_limit_number": top_k, + "mode": "mixture", } ) @@ -311,7 +313,7 @@ def add(self, messages, user_id, iso_date): agent_name=self.agent_id, session_date=iso_date, ) - self.wait_for_completion(response.task_id) + self.wait_for_completion(response.item_id) except Exception as error: print("āŒ Error saving conversation:", error) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index d1bc6efff..6de013313 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -361,8 +361,8 @@ def get_scheduler_config() -> dict[str, Any]: "thread_pool_max_workers": int( os.getenv("MOS_SCHEDULER_THREAD_POOL_MAX_WORKERS", "10") ), - "consume_interval_seconds": int( - os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "3") + "consume_interval_seconds": float( + os.getenv("MOS_SCHEDULER_CONSUME_INTERVAL_SECONDS", "0.01") ), "enable_parallel_dispatch": os.getenv( "MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH", "true" diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 1331094a8..f50d3ad75 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,9 +1,8 @@ -import json import os import traceback from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import TYPE_CHECKING, Any from fastapi import APIRouter, HTTPException @@ -33,11 +32,8 @@ from memos.mem_scheduler.orm_modules.base_model import BaseDBManager from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.general_schemas import ( - API_MIX_SEARCH_LABEL, SearchMode, ) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.memories.textual.prefer_text_memory.config import ( AdderConfigFactory, ExtractorConfigFactory, @@ -54,6 +50,10 @@ ) from memos.reranker.factory import RerankerFactory from memos.templates.instruction_completion import instruct_completion + + +if TYPE_CHECKING: + from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler from memos.types import MOSSearchResult, UserContext from memos.vec_dbs.factory import VecDBFactory @@ -154,7 +154,6 @@ def init_server(): # Build component configurations graph_db_config = _build_graph_db_config() - print(graph_db_config) llm_config = _build_llm_config() embedder_config = _build_embedder_config() mem_reader_config = _build_mem_reader_config() @@ -209,22 +208,6 @@ def init_server(): online_bot=False, ) - # Initialize Scheduler - scheduler_config_dict = APIConfig.get_scheduler_config() - scheduler_config = SchedulerConfigFactory( - backend="optimized_scheduler", config=scheduler_config_dict - ) - mem_scheduler = SchedulerFactory.from_config(scheduler_config) - mem_scheduler.initialize_modules( - chat_llm=llm, - process_llm=mem_reader.llm, - db_engine=BaseDBManager.create_default_sqlite_engine(), - ) - mem_scheduler.start() - - # Initialize SchedulerAPIModule - api_module = mem_scheduler.api_module - naive_mem_cube = NaiveMemCube( llm=llm, embedder=embedder, 
@@ -240,6 +223,23 @@ def init_server(): pref_retriever=pref_retriever, ) + # Initialize Scheduler + scheduler_config_dict = APIConfig.get_scheduler_config() + scheduler_config = SchedulerConfigFactory( + backend="optimized_scheduler", config=scheduler_config_dict + ) + mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config) + mem_scheduler.initialize_modules( + chat_llm=llm, + process_llm=mem_reader.llm, + db_engine=BaseDBManager.create_default_sqlite_engine(), + ) + mem_scheduler.current_mem_cube = naive_mem_cube + mem_scheduler.start() + + # Initialize SchedulerAPIModule + api_module = mem_scheduler.api_module + return ( graph_db, mem_reader, @@ -400,96 +400,12 @@ def mix_search_memories( """ Mix search memories: fast search + async fine search """ - # Get fast memories first - fast_memories = fast_search_memories(search_req, user_context) - - # Check if scheduler and dispatcher are available for async execution - if mem_scheduler and hasattr(mem_scheduler, "dispatcher") and mem_scheduler.dispatcher: - try: - # Create message for async fine search - message_content = { - "search_req": { - "query": search_req.query, - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "top_k": search_req.top_k, - "internet_search": search_req.internet_search, - "moscube": search_req.moscube, - "chat_history": search_req.chat_history, - }, - "user_context": {"mem_cube_id": user_context.mem_cube_id}, - } - - message = ScheduleMessageItem( - item_id=f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}", - user_id=search_req.user_id, - mem_cube_id=user_context.mem_cube_id, - label=API_MIX_SEARCH_LABEL, - mem_cube=naive_mem_cube, - content=json.dumps(message_content), - timestamp=get_utc_now(), - ) - - # Submit async task - mem_scheduler.dispatcher.submit_message(message) - logger.info(f"Submitted async fine search task for user {search_req.user_id}") - - # Try to get pre-computed fine memories if available - try: - pre_fine_memories = api_module.get_pre_fine_memories( - user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id - ) - if pre_fine_memories: - # Merge fast and pre-computed fine memories - all_memories = fast_memories + pre_fine_memories - # Remove duplicates based on content - seen_contents = set() - unique_memories = [] - for memory in all_memories: - content_key = memory.get("content", "") - if content_key not in seen_contents: - seen_contents.add(content_key) - unique_memories.append(memory) - return unique_memories - except Exception as e: - logger.warning(f"Failed to get pre-computed fine memories: {e}") - - except Exception as e: - logger.error(f"Failed to submit async fine search task: {e}") - # Fall back to synchronous execution - - # Fallback: synchronous fine search - try: - fine_memories = fine_search_memories(search_req, user_context) - - # Merge fast and fine memories - all_memories = fast_memories + fine_memories - - # Remove duplicates based on content - seen_contents = set() - unique_memories = [] - for memory in all_memories: - content_key = memory.get("content", "") - if content_key not in seen_contents: - seen_contents.add(content_key) - unique_memories.append(memory) - - # Sync search data to Redis - try: - api_module.sync_search_data( - user_id=search_req.user_id, - mem_cube_id=user_context.mem_cube_id, - query=search_req.query, - formatted_memories=unique_memories, - ) - except Exception as e: - logger.error(f"Failed to sync search data: {e}") - - return unique_memories - except Exception as e: - 
logger.error(f"Fine search failed: {e}") - return fast_memories + formatted_memories = mem_scheduler.mix_search_memories( + search_req=search_req, + user_context=user_context, + ) + return formatted_memories def fine_search_memories( diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index bc22cfb63..e757f243b 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -15,6 +15,7 @@ DEFAULT_CONSUME_INTERVAL_SECONDS, DEFAULT_CONTEXT_WINDOW_SIZE, DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, + DEFAULT_MULTI_TASK_RUNNING_TIMEOUT, DEFAULT_THREAD_POOL_MAX_WORKERS, DEFAULT_TOP_K, DEFAULT_USE_REDIS_QUEUE, @@ -59,6 +60,10 @@ class BaseSchedulerConfig(BaseConfig): default=DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, description="Maximum size of internal message queue when not using Redis", ) + multi_task_running_timeout: int = Field( + default=DEFAULT_MULTI_TASK_RUNNING_TIMEOUT, + description="Default timeout for multi-task running operations in seconds", + ) class GeneralSchedulerConfig(BaseSchedulerConfig): diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py index 45a39e0de..28ca182e5 100644 --- a/src/memos/mem_scheduler/analyzer/api_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py @@ -7,6 +7,7 @@ import http.client import json +import time from typing import Any from urllib.parse import urlparse @@ -364,11 +365,204 @@ def __init__(self): self.UserContext = UserContext self.MessageDict = MessageDict + # Initialize conversation history for continuous conversation support + self.conversation_history = [] + self.current_session_id = None + self.current_user_id = None + self.current_mem_cube_id = None + logger.info("DirectSearchMemoriesAnalyzer initialized successfully") except ImportError as e: logger.error(f"Failed to import modules: {e}") raise + def start_conversation(self, user_id="test_user", mem_cube_id="test_cube", session_id=None): + """ + Start a new conversation session for continuous dialogue. + + Args: + user_id: User ID for the conversation + mem_cube_id: Memory cube ID for the conversation + session_id: Session ID for the conversation (auto-generated if None) + """ + self.current_user_id = user_id + self.current_mem_cube_id = mem_cube_id + self.current_session_id = ( + session_id or f"session_{hash(user_id + mem_cube_id)}_{len(self.conversation_history)}" + ) + self.conversation_history = [] + + logger.info(f"Started conversation session: {self.current_session_id}") + print(f"šŸš€ Started new conversation session: {self.current_session_id}") + print(f" User ID: {self.current_user_id}") + print(f" Mem Cube ID: {self.current_mem_cube_id}") + + def add_to_conversation(self, user_message, assistant_message=None): + """ + Add messages to the current conversation and store them in memory. + + Args: + user_message: User's message content + assistant_message: Assistant's response (optional) + + Returns: + Result from add_memories function + """ + if not self.current_session_id: + raise ValueError("No active conversation session. 
Call start_conversation() first.") + + # Prepare messages for adding to memory + messages = [{"role": "user", "content": user_message}] + if assistant_message: + messages.append({"role": "assistant", "content": assistant_message}) + + # Add to conversation history + self.conversation_history.extend(messages) + + # Create add request + add_req = self.create_test_add_request( + user_id=self.current_user_id, + mem_cube_id=self.current_mem_cube_id, + messages=messages, + session_id=self.current_session_id, + ) + + print(f"šŸ’¬ Adding to conversation (Session: {self.current_session_id}):") + print(f" User: {user_message}") + if assistant_message: + print(f" Assistant: {assistant_message}") + + # Add to memory + result = self.add_memories(add_req) + print(" āœ… Added to memory successfully") + + return result + + def search_in_conversation(self, query, mode="fast", top_k=10, include_history=True): + """ + Search memories within the current conversation context. + + Args: + query: Search query + mode: Search mode ("fast", "fine", or "mixture") + top_k: Number of results to return + include_history: Whether to include conversation history in the search + + Returns: + Search results + """ + if not self.current_session_id: + raise ValueError("No active conversation session. Call start_conversation() first.") + + # Prepare chat history if requested + chat_history = self.conversation_history if include_history else None + + # Create search request + search_req = self.create_test_search_request( + query=query, + user_id=self.current_user_id, + mem_cube_id=self.current_mem_cube_id, + mode=mode, + top_k=top_k, + chat_history=chat_history, + session_id=self.current_session_id, + ) + + print(f"šŸ” Searching in conversation (Session: {self.current_session_id}):") + print(f" Query: {query}") + print(f" Mode: {mode}") + print(f" Top K: {top_k}") + print(f" Include History: {include_history}") + print(f" History Length: {len(self.conversation_history) if chat_history else 0}") + + # Perform search + result = self.search_memories(search_req) + + print(" āœ… Search completed") + if hasattr(result, "data") and result.data: + total_memories = sum( + len(mem_list) for mem_list in result.data.values() if isinstance(mem_list, list) + ) + print(f" šŸ“Š Found {total_memories} total memories") + + return result + + def test_continuous_conversation(self): + """Test continuous conversation functionality""" + print("=" * 80) + print("Testing Continuous Conversation Functionality") + print("=" * 80) + + try: + # Start a conversation + self.start_conversation(user_id="conv_test_user", mem_cube_id="conv_test_cube") + + # Prepare all conversation messages for batch addition + all_messages = [ + { + "role": "user", + "content": "I'm planning a trip to Shanghai for New Year's Eve. What are some good places to visit?", + }, + { + "role": "assistant", + "content": "Shanghai has many great places for New Year's Eve! You could visit the Bund for the countdown, go to a rooftop party, or enjoy fireworks at Disneyland Shanghai. The French Concession also has nice bars and restaurants.", + }, + {"role": "user", "content": "What about food? Any restaurant recommendations?"}, + { + "role": "assistant", + "content": "For New Year's Eve dining in Shanghai, I'd recommend trying some local specialties like xiaolongbao at Din Tai Fung, or for a fancy dinner, you could book at restaurants in the Bund area with great views.", + }, + {"role": "user", "content": "I'm on a budget though. 
Any cheaper alternatives?"}, + { + "role": "assistant", + "content": "For budget-friendly options, try street food in Yuyuan Garden area, local noodle shops, or food courts in shopping malls. You can also watch the fireworks from free public areas along the Huangpu River.", + }, + ] + + # Add all conversation messages at once + print("\nšŸ“ Adding all conversation messages at once:") + add_req = self.create_test_add_request( + user_id=self.current_user_id, + mem_cube_id=self.current_mem_cube_id, + messages=all_messages, + session_id=self.current_session_id, + ) + + print( + f"šŸ’¬ Adding {len(all_messages)} messages to conversation (Session: {self.current_session_id})" + ) + self.add_memories(add_req) + + # Update conversation history + self.conversation_history.extend(all_messages) + print(" āœ… Added all messages to memory successfully") + + # Test searching within the conversation + print("\nšŸ” Testing search within conversation:") + + # Search for trip-related information + self.search_in_conversation( + query="New Year's Eve Shanghai recommendations", mode="mixture", top_k=5 + ) + + # Search for food-related information + self.search_in_conversation(query="budget food Shanghai", mode="mixture", top_k=3) + + # Search without conversation history + self.search_in_conversation( + query="Shanghai travel", mode="mixture", top_k=3, include_history=False + ) + + print("\nāœ… Continuous conversation test completed successfully!") + return True + + except Exception as e: + print(f"āŒ Continuous conversation test failed: {e}") + import traceback + + traceback.print_exc() + return False + def create_test_search_request( self, query="test query", @@ -451,115 +645,19 @@ def create_test_add_request( operation=None, ) - def test_add_memories_basic(self, user_id="test_user_add", mem_cube_id="test_cube_add"): - """Basic add_memories test""" - print("=" * 60) - print("Starting basic add_memories test") - print("=" * 60) - - try: - # Create test request with default messages - add_req = self.create_test_add_request(user_id=user_id, mem_cube_id=mem_cube_id) - - print("Test request created:") - print(f" User ID: {add_req.user_id}") - print(f" Mem Cube ID: {add_req.mem_cube_id}") - print(f" Messages: {add_req.messages}") - print(f" Session ID: {add_req.session_id}") - - # Call add_memories function - print("\nCalling add_memories function...") - result = self.add_memories(add_req) - - print(f"Add result: {result}") - print("Basic add_memories test completed successfully") - return result - - except Exception as e: - print(f"Basic add_memories test failed: {e}") - import traceback - - traceback.print_exc() - return None - - def test_search_memories_basic(self, query: str, mode: str, topk: int): - """Basic search_memories test""" - print("=" * 60) - print("Starting basic search_memories test") - print("=" * 60) - - try: - # Create test request - search_req = self.create_test_search_request( - query=query, - user_id="test_user_id", - mem_cube_id="test_mem_cube_id", - mode=mode, - top_k=topk, - ) - - print("Test request parameters:") - print(f" - query: {search_req.query}") - print(f" - user_id: {search_req.user_id}") - print(f" - mem_cube_id: {search_req.mem_cube_id}") - print(f" - mode: {search_req.mode}") - print(f" - top_k: {search_req.top_k}") - print(f" - internet_search: {search_req.internet_search}") - print(f" - moscube: {search_req.moscube}") - print() - - # Call search_memories function - print("Calling search_memories function...") - result = self.search_memories(search_req) - - print("āœ… 
Function call successful!") - print(f"Return result type: {type(result)}") - print(f"Return result: {result}") - - # Analyze return result - if hasattr(result, "message"): - print(f"Message: {result.message}") - if hasattr(result, "data"): - print(f"Data type: {type(result.data)}") - if result.data and isinstance(result.data, dict): - for key, value in result.data.items(): - print(f" {key}: {len(value) if isinstance(value, list) else value}") - - return result - - except Exception as e: - print(f"āŒ Test failed: {e}") - import traceback - - print("Detailed error information:") - traceback.print_exc() - return None - def run_all_tests(self): """Run all available tests""" print("šŸš€ Starting comprehensive test suite") print("=" * 80) - # Test add_memories functions (more likely to have dependency issues) - print("\n\nšŸ“ Testing ADD_MEMORIES functions:") - try: - print("\n" + "-" * 40) - self.test_add_memories_basic() - print("āœ… Basic add memories test completed") - except Exception as e: - print(f"āŒ Basic add memories test failed: {e}") - - # Test search_memories functions first (less likely to fail) - print("\nšŸ” Testing SEARCH_MEMORIES functions:") + # Test continuous conversation functionality + print("\nšŸ’¬ Testing CONTINUOUS CONVERSATION functions:") try: - self.test_search_memories_basic( - query="What are some good places to celebrate New Year's Eve in Shanghai?", - mode="fast", - topk=3, - ) - print("āœ… Search memories test completed successfully") + self.test_continuous_conversation() + time.sleep(5) + print("āœ… Continuous conversation test completed successfully") except Exception as e: - print(f"āŒ Search memories test failed: {e}") + print(f"āŒ Continuous conversation test failed: {e}") print("\n" + "=" * 80) print("āœ… All tests completed!") diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index e475ea225..3958ee382 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -502,7 +502,7 @@ def update_activation_memory_periodically( except Exception as e: logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) - async def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): """Submit messages to the message queue (either local queue or Redis).""" if isinstance(messages, ScheduleMessageItem): messages = [messages] # transform single message to list @@ -519,7 +519,7 @@ async def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMes if self.use_redis_queue: # Use Redis stream for message queue - await self.redis_add_message_stream(message.to_dict()) + self.redis_add_message_stream(message.to_dict()) logger.info(f"Submitted message to Redis: {message.label} - {message.content}") else: # Use local queue @@ -774,34 +774,6 @@ def unregister_handlers(self, labels: list[str]) -> dict[str, bool]: return self.dispatcher.unregister_handlers(labels) def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, dict]: - """ - Get currently running tasks, optionally filtered by a custom function. - - This method delegates to the dispatcher's get_running_tasks method. - - Args: - filter_func: Optional function to filter tasks. Should accept a RunningTaskItem - and return True if the task should be included in results. - - Returns: - dict[str, dict]: Dictionary mapping task IDs to task information dictionaries. 
- Each task dict contains: item_id, user_id, mem_cube_id, task_info, - task_name, start_time, end_time, status, result, error_message, messages - - Examples: - # Get all running tasks - all_tasks = scheduler.get_running_tasks() - - # Get tasks for specific user - user_tasks = scheduler.get_running_tasks( - filter_func=lambda task: task.user_id == "user123" - ) - - # Get tasks with specific status - active_tasks = scheduler.get_running_tasks( - filter_func=lambda task: task.status == "running" - ) - """ if not self.dispatcher: logger.warning("Dispatcher is not initialized, returning empty tasks dict") return {} diff --git a/src/memos/mem_scheduler/general_modules/api_misc.py b/src/memos/mem_scheduler/general_modules/api_misc.py index 6139a895a..bb993de38 100644 --- a/src/memos/mem_scheduler/general_modules/api_misc.py +++ b/src/memos/mem_scheduler/general_modules/api_misc.py @@ -1,115 +1,145 @@ -import threading - from typing import Any from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager +from memos.mem_scheduler.orm_modules.api_redis_model import APIRedisDBManager +from memos.mem_scheduler.schemas.api_schemas import ( + APIMemoryHistoryEntryItem, + APISearchHistoryManager, + TaskRunningStatus, +) +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.memories.textual.item import TextualMemoryItem logger = get_logger(__name__) class SchedulerAPIModule(BaseSchedulerModule): - def __init__(self): + def __init__(self, window_size=5): super().__init__() + self.window_size = window_size + self.search_history_managers: dict[str, APIRedisDBManager] = {} + self.pre_memory_turns = 5 - self.search_history_managers: dict[str, RedisDBManager] = {} - - def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> RedisDBManager: + def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> APIRedisDBManager: """Get or create a Redis manager for search history.""" key = f"search_history:{user_id}:{mem_cube_id}" if key not in self.search_history_managers: - self.search_history_managers[key] = RedisDBManager( - user_id=user_id, mem_cube_id=mem_cube_id + self.search_history_managers[key] = APIRedisDBManager( + user_id=user_id, + mem_cube_id=mem_cube_id, + obj=APISearchHistoryManager(window_size=self.window_size), ) return self.search_history_managers[key] def sync_search_data( - self, user_id: str, mem_cube_id: str, query: str, formatted_memories: Any - ) -> None: - """ - Sync search data to Redis, maintaining a list of size 5. 
+ self, + item_id: str, + user_id: str, + mem_cube_id: str, + query: str, + memories: list[TextualMemoryItem], + formatted_memories: Any, + conversation_id: str | None = None, + ) -> Any: + # Get the search history manager + manager = self.get_search_history_manager(user_id, mem_cube_id) + manager.sync_with_redis(size_limit=self.window_size) + + search_history = manager.obj + + # Check if entry with item_id already exists + existing_entry, location = search_history.find_entry_by_item_id(item_id) + + if existing_entry is not None: + # Update existing entry + success = search_history.update_entry_by_item_id( + item_id=item_id, + query=query, + formatted_memories=formatted_memories, + task_status=TaskRunningStatus.COMPLETED, # Use the provided running_status + conversation_id=conversation_id, + memories=memories, + ) - Args: - user_id: User identifier - mem_cube_id: Memory cube identifier - query: Search query string - formatted_memories: Formatted search results - """ - try: - # Get the search history manager - manager = self.get_search_history_manager(user_id, mem_cube_id) - - # Create search data entry - search_entry = { - "query": query, - "formatted_memories": formatted_memories, - "timestamp": threading.current_thread().ident, # Use thread ID as simple timestamp - } - - # Load existing search history - existing_data = manager.load_from_db() - - if existing_data is None: - search_history = SimpleListManager([]) + if success: + logger.info(f"Updated existing entry with item_id: {item_id} in {location} list") else: - # If existing data is a SimpleListManager, use it; otherwise create new one - if isinstance(existing_data, SimpleListManager): - search_history = existing_data - else: - search_history = SimpleListManager([]) - - # Add new entry and keep only latest 5 - search_history.add_item(str(search_entry)) - if len(search_history) > 5: - # Keep only the latest 5 items - search_history.items = search_history.items[-5:] - - # Save back to Redis - manager.save_to_db(search_history) - - logger.info( - f"Synced search data for user {user_id}, mem_cube {mem_cube_id}. 
History size: {len(search_history)}" + logger.warning(f"Failed to update entry with item_id: {item_id}") + else: + # Add new entry based on running_status + search_entry = APIMemoryHistoryEntryItem( + item_id=item_id, + query=query, + formatted_memories=formatted_memories, + memories=memories, + task_status=TaskRunningStatus.COMPLETED, + conversation_id=conversation_id, + created_time=get_utc_now(), ) - except Exception as e: - logger.error(f"Failed to sync search data: {e}", exc_info=True) + # Add directly to completed list as APIMemoryHistoryEntryItem instance + search_history.completed_entries.append(search_entry) + + # Maintain window size + if len(search_history.completed_entries) > search_history.window_size: + search_history.completed_entries = search_history.completed_entries[ + -search_history.window_size : + ] + + # Remove from running task IDs + if item_id in search_history.running_item_ids: + search_history.running_item_ids.remove(item_id) + + logger.info(f"Created new entry with item_id: {item_id}") - def get_pre_fine_memories(self, user_id: str, mem_cube_id: str) -> list: + # Update manager's object with the modified search history + manager.obj = search_history + + # Use sync_with_redis to handle Redis synchronization with merging + manager.sync_with_redis(size_limit=self.window_size) + return manager + + def get_pre_memories(self, user_id: str, mem_cube_id: str) -> list: """ - Get the most recent pre-computed fine memories from search history. + Get pre-computed memories from the most recent completed search entry. Args: user_id: User identifier mem_cube_id: Memory cube identifier Returns: - List of formatted memories from the most recent search, or empty list if none found + List of TextualMemoryItem objects from the most recent completed search """ - try: - manager = self.get_search_history_manager(user_id, mem_cube_id) - search_history_key = "search_history_list" - existing_data = manager.load_from_db(search_history_key) + manager = self.get_search_history_manager(user_id, mem_cube_id) - if existing_data is None: - return [] + existing_data = manager.load_from_db() + if existing_data is None: + return [] - search_history = ( - existing_data.obj_instance - if hasattr(existing_data, "obj_instance") - else existing_data - ) + search_history: APISearchHistoryManager = existing_data - if not search_history or len(search_history) == 0: - return [] + # Get memories from the most recent completed entry + history_memories = search_history.get_history_memories(turns=self.pre_memory_turns) + return history_memories - # Return the formatted_memories from the most recent search - latest_entry = search_history[-1] - return ( - latest_entry.get("formatted_memories", []) if isinstance(latest_entry, dict) else [] - ) + def get_history_memories(self, user_id: str, mem_cube_id: str, n: int) -> list: + """Get history memories for backward compatibility with tests.""" + manager = self.get_search_history_manager(user_id, mem_cube_id) + existing_data = manager.load_from_db() - except Exception as e: - logger.error(f"Failed to get pre-computed fine memories: {e}", exc_info=True) + if existing_data is None: return [] + + # Handle different data formats + if isinstance(existing_data, APISearchHistoryManager): + search_history = existing_data + else: + # Try to convert to APISearchHistoryManager + try: + search_history = APISearchHistoryManager(**existing_data) + except Exception: + return [] + + return search_history.get_history_memories(turns=n) diff --git 
a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index c357e31b5..2e5779f19 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -36,6 +36,11 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): # Main dispatcher thread pool self.max_workers = max_workers + # Get multi-task timeout from config + self.multi_task_running_timeout = ( + self.config.get("multi_task_running_timeout") if self.config else None + ) + # Only initialize thread pool if in parallel mode self.enable_parallel_dispatch = enable_parallel_dispatch self.thread_name_prefix = "dispatcher" @@ -62,6 +67,8 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): # Task tracking for monitoring self._running_tasks: dict[str, RunningTaskItem] = {} self._task_lock = threading.Lock() + self._completed_tasks = [] + self.completed_tasks_max_show_size = 10 def _create_task_wrapper(self, handler: Callable, task_item: RunningTaskItem): """ @@ -85,7 +92,9 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): if task_item.item_id in self._running_tasks: task_item.mark_completed(result) del self._running_tasks[task_item.item_id] - + self._completed_tasks.append(task_item) + if len(self._completed_tasks) > self.completed_tasks_max_show_size: + self._completed_tasks[-self.completed_tasks_max_show_size :] logger.info(f"Task completed: {task_item.get_execution_info()}") return result @@ -95,7 +104,8 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): if task_item.item_id in self._running_tasks: task_item.mark_failed(str(e)) del self._running_tasks[task_item.item_id] - + if len(self._completed_tasks) > self.completed_tasks_max_show_size: + self._completed_tasks[-self.completed_tasks_max_show_size :] logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}") raise @@ -356,17 +366,17 @@ def run_competitive_tasks( def run_multiple_tasks( self, - tasks: dict[str, tuple[Callable, tuple, dict]], + tasks: dict[str, tuple[Callable, tuple]], use_thread_pool: bool | None = None, - timeout: float | None = 30.0, + timeout: float | None = None, ) -> dict[str, Any]: """ Execute multiple tasks concurrently and return all results. Args: - tasks: Dictionary mapping task names to (function, args, kwargs) tuples + tasks: Dictionary mapping task names to (task_execution_function, task_execution_parameters) tuples use_thread_pool: Whether to use ThreadPoolExecutor. If None, uses dispatcher's parallel mode setting - timeout: Maximum time to wait for all tasks to complete (in seconds). None for infinite timeout. + timeout: Maximum time to wait for all tasks to complete (in seconds). If None, uses config default. 
Returns: Dictionary mapping task names to their results @@ -378,7 +388,13 @@ def run_multiple_tasks( if use_thread_pool is None: use_thread_pool = self.enable_parallel_dispatch - logger.info(f"Executing {len(tasks)} tasks concurrently (thread_pool: {use_thread_pool})") + # Use config timeout if not explicitly provided + if timeout is None: + timeout = self.multi_task_running_timeout + + logger.info( + f"Executing {len(tasks)} tasks concurrently (thread_pool: {use_thread_pool}, timeout: {timeout})" + ) try: results = self.thread_manager.run_multiple_tasks( diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py index 6f05bf72f..b6f48d043 100644 --- a/src/memos/mem_scheduler/general_modules/misc.py +++ b/src/memos/mem_scheduler/general_modules/misc.py @@ -127,7 +127,7 @@ class DictConversionMixin: @field_serializer("timestamp", check_fields=False) def serialize_datetime(self, dt: datetime | None, _info) -> str | None: """ - Custom datetime serialization logic. + Custom timestamp serialization logic. - Supports timezone-aware datetime objects - Compatible with models without timestamp field (via check_fields=False) """ diff --git a/src/memos/mem_scheduler/general_modules/task_threads.py b/src/memos/mem_scheduler/general_modules/task_threads.py index 913d5fa1d..551e8b726 100644 --- a/src/memos/mem_scheduler/general_modules/task_threads.py +++ b/src/memos/mem_scheduler/general_modules/task_threads.py @@ -89,7 +89,7 @@ def worker( def run_multiple_tasks( self, - tasks: dict[str, tuple[Callable, tuple, dict]], + tasks: dict[str, tuple[Callable, tuple]], use_thread_pool: bool = False, timeout: float | None = None, ) -> dict[str, Any]: @@ -97,7 +97,7 @@ def run_multiple_tasks( Run multiple tasks concurrently and return all results. Args: - tasks: Dictionary mapping task names to (function, args, kwargs) tuples + tasks: Dictionary mapping task names to (task_execution_function, task_execution_parameters) tuples use_thread_pool: Whether to use ThreadPoolExecutor (True) or regular threads (False) timeout: Maximum time to wait for all tasks to complete (in seconds). None for infinite timeout. 
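Side note (not part of the patch): with the two-tuple task format, callers now pass (function, args) pairs and can rely on the config-driven timeout instead of the old hard-coded 30 seconds. A minimal hypothetical sketch follows, assuming SchedulerDispatcher is importable from this module and accepts a plain dict as its config; the probe functions and values are illustrative only.

# Hypothetical usage sketch for the new (function, args) task format; names are illustrative.
from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher

def fast_probe(query: str) -> list[str]:
    return [f"fast::{query}"]

def fine_probe(query: str, top_k: int) -> list[str]:
    return [f"fine::{query}::{i}" for i in range(top_k)]

dispatcher = SchedulerDispatcher(
    max_workers=4,
    enable_parallel_dispatch=True,
    config={"multi_task_running_timeout": 30},  # read when timeout is not given explicitly
)

tasks = {
    "fast_search": (fast_probe, ("shanghai places",)),   # args tuple only, no kwargs dict anymore
    "fine_search": (fine_probe, ("shanghai places", 5)),
}

# timeout=None now falls back to multi_task_running_timeout from the config
results = dispatcher.run_multiple_tasks(tasks, use_thread_pool=True, timeout=None)
print(results)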
@@ -115,17 +115,21 @@ def run_multiple_tasks( start_time = time.time() if use_thread_pool: - return self.run_with_thread_pool(tasks, timeout) + # Convert tasks format for thread pool compatibility + thread_pool_tasks = {} + for task_name, (func, args) in tasks.items(): + thread_pool_tasks[task_name] = (func, args, {}) + return self.run_with_thread_pool(thread_pool_tasks, timeout) else: # Use regular threads threads = {} thread_results = {} exceptions = {} - def worker(task_name: str, func: Callable, args: tuple, kwargs: dict): + def worker(task_name: str, func: Callable, args: tuple): """Worker function for regular threads""" try: - result = func(*args, **kwargs) + result = func(*args) thread_results[task_name] = result logger.debug(f"Task '{task_name}' completed successfully") except Exception as e: @@ -133,9 +137,9 @@ def worker(task_name: str, func: Callable, args: tuple, kwargs: dict): logger.error(f"Task '{task_name}' failed with error: {e}") # Start all threads - for task_name, (func, args, kwargs) in tasks.items(): + for task_name, (func, args) in tasks.items(): thread = threading.Thread( - target=worker, args=(task_name, func, args, kwargs), name=f"task-{task_name}" + target=worker, args=(task_name, func, args), name=f"task-{task_name}" ) threads[task_name] = thread thread.start() @@ -197,44 +201,60 @@ def run_with_thread_pool( results = {} start_time = time.time() - # Use ThreadPoolExecutor for better resource management - with self.thread_pool_executor as executor: - # Submit all tasks - future_to_name = {} - for task_name, (func, args, kwargs) in tasks.items(): + # Check if executor is shutdown before using it + if self.thread_pool_executor._shutdown: + logger.error("ThreadPoolExecutor is already shutdown, cannot submit new tasks") + raise RuntimeError("ThreadPoolExecutor is already shutdown") + + # Use ThreadPoolExecutor directly without context manager + # The executor lifecycle is managed by the parent SchedulerDispatcher + executor = self.thread_pool_executor + + # Submit all tasks + future_to_name = {} + for task_name, (func, args, kwargs) in tasks.items(): + try: future = executor.submit(func, *args, **kwargs) future_to_name[future] = task_name logger.debug(f"Submitted task '{task_name}' to thread pool") + except RuntimeError as e: + if "cannot schedule new futures after shutdown" in str(e): + logger.error( + f"Cannot submit task '{task_name}': ThreadPoolExecutor is shutdown" + ) + results[task_name] = None + else: + raise - # Collect results as they complete - try: - # Handle infinite timeout case - timeout_param = None if timeout is None else timeout - for future in as_completed(future_to_name, timeout=timeout_param): - task_name = future_to_name[future] - try: - result = future.result() - results[task_name] = result - logger.debug(f"Task '{task_name}' completed successfully") - except Exception as e: - logger.error(f"Task '{task_name}' failed with error: {e}") - results[task_name] = None + # Collect results as they complete + try: + # Handle infinite timeout case + timeout_param = None if timeout is None else timeout + for future in as_completed(future_to_name, timeout=timeout_param): + task_name = future_to_name[future] + try: + result = future.result() + results[task_name] = result + logger.debug(f"Task '{task_name}' completed successfully") + except Exception as e: + logger.error(f"Task '{task_name}' failed with error: {e}") + results[task_name] = None - except Exception: - elapsed_time = time.time() - start_time - timeout_msg = "infinite" if timeout is None else 
f"{timeout}s" - logger.error( - f"Tasks execution timed out after {elapsed_time:.2f} seconds (timeout: {timeout_msg})" - ) - # Cancel remaining futures - for future in future_to_name: - if not future.done(): - future.cancel() - task_name = future_to_name[future] - logger.warning(f"Cancelled task '{task_name}' due to timeout") - results[task_name] = None - timeout_seconds = "infinite" if timeout is None else timeout - logger.error(f"Tasks execution timed out after {timeout_seconds} seconds") + except Exception: + elapsed_time = time.time() - start_time + timeout_msg = "infinite" if timeout is None else f"{timeout}s" + logger.error( + f"Tasks execution timed out after {elapsed_time:.2f} seconds (timeout: {timeout_msg})" + ) + # Cancel remaining futures + for future in future_to_name: + if not future.done(): + future.cancel() + task_name = future_to_name[future] + logger.warning(f"Cancelled task '{task_name}' due to timeout") + results[task_name] = None + timeout_seconds = "infinite" if timeout is None else timeout + logger.error(f"Tasks execution timed out after {timeout_seconds} seconds") return results diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index fb5f4ce7c..c8e2eb59e 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Any +import json + +from typing import TYPE_CHECKING from memos.api.product_models import APISearchRequest from memos.configs.mem_scheduler import GeneralSchedulerConfig @@ -8,18 +10,20 @@ from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( API_MIX_SEARCH_LABEL, - QUERY_LABEL, MemCubeID, SearchMode, UserID, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.utils.api_utils import format_textual_memory_item +from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory from memos.types import UserContext if TYPE_CHECKING: from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem + from memos.reranker.http_bge import HTTPBGEReranker logger = get_logger(__name__) @@ -31,30 +35,18 @@ class OptimizedScheduler(GeneralScheduler): def __init__(self, config: GeneralSchedulerConfig): super().__init__(config) self.api_module = SchedulerAPIModule() - self.message_consumers = { - API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, - } - - def _format_memory_item(self, memory_data: Any) -> dict[str, Any]: - """Format a single memory item for API response.""" - memory = memory_data.model_dump() - memory_id = memory["id"] - ref_id = f"[{memory_id.split('-')[0]}]" - - memory["ref_id"] = ref_id - memory["metadata"]["embedding"] = [] - memory["metadata"]["sources"] = [] - memory["metadata"]["ref_id"] = ref_id - memory["metadata"]["id"] = memory_id - memory["metadata"]["memory"] = memory["memory"] - - return memory + self.register_handlers( + { + API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, + } + ) - def fine_search_memories( + def search_memories( self, search_req: APISearchRequest, user_context: UserContext, mem_cube: GeneralMemCube, + mode: SearchMode, ): """Fine search memories function copied from server_router to avoid circular import""" target_session_id = search_req.session_id @@ -67,7 +59,7 @@ def fine_search_memories( query=search_req.query, 
user_name=user_context.mem_cube_id, top_k=search_req.top_k, - mode=SearchMode.FINE, + mode=mode, manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, @@ -77,42 +69,145 @@ def fine_search_memories( "chat_history": search_req.chat_history, }, ) - formatted_memories = [self._format_memory_item(data) for data in search_results] + return search_results + + def submit_memory_history_async_task( + self, + search_req: APISearchRequest, + user_context: UserContext, + ): + # Create message for async fine search + message_content = { + "search_req": { + "query": search_req.query, + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "top_k": search_req.top_k, + "internet_search": search_req.internet_search, + "moscube": search_req.moscube, + "chat_history": search_req.chat_history, + }, + "user_context": {"mem_cube_id": user_context.mem_cube_id}, + } + + async_task_id = f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}" + + # Get mem_cube for the message + mem_cube = self.current_mem_cube + + message = ScheduleMessageItem( + item_id=async_task_id, + user_id=search_req.user_id, + mem_cube_id=user_context.mem_cube_id, + label=API_MIX_SEARCH_LABEL, + mem_cube=mem_cube, + content=json.dumps(message_content), + timestamp=get_utc_now(), + ) + + # Submit async task + self.submit_messages([message]) + logger.info(f"Submitted async fine search task for user {search_req.user_id}") + return async_task_id + + def mix_search_memories( + self, + search_req: APISearchRequest, + user_context: UserContext, + ): + """ + Mix search memories: fast search + async fine search + """ + + # Get mem_cube for fast search + mem_cube = self.current_mem_cube + + # Perform fast search + fast_memories = self.search_memories( + search_req=search_req, + user_context=user_context, + mem_cube=mem_cube, + mode=SearchMode.FAST, + ) + + self.submit_memory_history_async_task( + search_req=search_req, + user_context=user_context, + ) + + # Try to get pre-computed fine memories if available + pre_fine_memories = self.api_module.get_pre_memories( + user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id + ) + if not pre_fine_memories: + # Format fast memories for return + formatted_memories = [format_textual_memory_item(data) for data in fast_memories] + return formatted_memories + + # Merge fast and pre-computed fine memories (both are TextualMemoryItem objects) + combined_memories = fast_memories + pre_fine_memories + # Remove duplicates based on memory content + seen_contents = set() + unique_memories = [] + for memory in combined_memories: + # Both fast_memories and pre_fine_memories are TextualMemoryItem objects + content_key = memory.memory # Use .memory attribute instead of .get("content", "") + if content_key not in seen_contents: + seen_contents.add(content_key) + unique_memories.append(memory) + + # Rerank Memories - reranker expects TextualMemoryItem objects + reranker: HTTPBGEReranker = mem_cube.text_mem.reranker + + # Use search_req parameters for reranking + target_session_id = search_req.session_id + if not target_session_id: + target_session_id = "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + sorted_results = reranker.rerank( + query=search_req.query, # Use search_req.query instead of undefined query + graph_results=unique_memories, # Pass TextualMemoryItem objects directly + top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k + 
search_filter=search_filter, + ) + + formatted_memories = [ + format_textual_memory_item(item) for item, score in sorted_results[: search_req.top_k] + ] return formatted_memories def update_search_memories_to_redis( - self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] + self, + user_id: str, + mem_cube_id: str, + messages: list[ScheduleMessageItem], ): mem_cube = messages[0].mem_cube - # for status update - self._set_current_context_from_message(msg=messages[0]) - - # update query monitors for msg in messages: - self.monitor.register_query_monitor_if_not_exists( - user_id=user_id, mem_cube_id=mem_cube_id - ) - - content_dict = msg.content + content_dict = json.loads(msg.content) search_req = content_dict["search_req"] user_context = content_dict["user_context"] - formatted_memories = self.fine_search_memories( - search_req=search_req, user_context=user_context, mem_cube=mem_cube + fine_memories: list[TextualMemoryItem] = self.search_memories( + search_req=APISearchRequest(**content_dict["search_req"]), + user_context=UserContext(**content_dict["user_context"]), + mem_cube=mem_cube, + mode=SearchMode.FINE, ) + formatted_memories = [format_textual_memory_item(data) for data in fine_memories] # Sync search data to Redis - try: - self.api_module.sync_search_data( - user_id=search_req.user_id, - mem_cube_id=user_context.mem_cube_id, - query=search_req.query, - formatted_memories=formatted_memories, - ) - except Exception as e: - logger.error(f"Failed to sync search data: {e}") + self.api_module.sync_search_data( + item_id=msg.item_id, + user_id=search_req["user_id"], + mem_cube_id=user_context["mem_cube_id"], + query=search_req["query"], + memories=fine_memories, + formatted_memories=formatted_memories, + ) def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ @@ -121,12 +216,12 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) Args: messages: List of query messages to process """ - logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") + logger.info(f"Messages {messages} assigned to {API_MIX_SEARCH_LABEL} handler.") # Process the query in a session turn grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) - self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) + self.validate_schedule_messages(messages=messages, label=API_MIX_SEARCH_LABEL) for user_id in grouped_messages: for mem_cube_id in grouped_messages[user_id]: diff --git a/src/memos/mem_scheduler/orm_modules/api_redis_model.py b/src/memos/mem_scheduler/orm_modules/api_redis_model.py new file mode 100644 index 000000000..04cd7e833 --- /dev/null +++ b/src/memos/mem_scheduler/orm_modules/api_redis_model.py @@ -0,0 +1,517 @@ +import os +import time + +from typing import Any + +from sqlalchemy.orm import declarative_base + +from memos.log import get_logger +from memos.mem_scheduler.orm_modules.base_model import DatabaseError +from memos.mem_scheduler.schemas.api_schemas import ( + APISearchHistoryManager, +) +from memos.mem_scheduler.utils.db_utils import get_utc_now + + +logger = get_logger(__name__) + +Base = declarative_base() + + +class APIRedisDBManager: + """Redis-based database manager for any serializable object + + This class handles persistence, synchronization, and locking + for any object that implements to_json/from_json methods using Redis as the backend storage. 
+ """ + + # Add orm_class attribute for compatibility + orm_class = None + + def __init__( + self, + user_id: str | None = None, + mem_cube_id: str | None = None, + obj: APISearchHistoryManager | None = None, + lock_timeout: int = 10, + redis_client=None, + redis_config: dict | None = None, + window_size: int = 5, + ): + """Initialize the Redis database manager + + Args: + user_id: Unique identifier for the user + mem_cube_id: Unique identifier for the memory cube + obj: Optional object instance to manage (must have to_json/from_json methods) + lock_timeout: Timeout in seconds for lock acquisition + redis_client: Redis client instance (optional) + redis_config: Redis configuration dictionary (optional) + """ + # Initialize Redis client + self.redis_client = redis_client + self.redis_config = redis_config or {} + + if self.redis_client is None: + self._init_redis_client() + + # Initialize base attributes without calling parent's init_manager + self.user_id = user_id + self.mem_cube_id = mem_cube_id + self.obj = obj + self.lock_timeout = lock_timeout + self.engine = None # Keep for compatibility but not used + self.SessionLocal = None # Not used for Redis + self.window_size = window_size + self.lock_key = f"{self._get_key_prefix()}:lock" + + logger.info( + f"RedisDBManager initialized for user_id: {user_id}, mem_cube_id: {mem_cube_id}" + ) + logger.info(f"Redis client: {type(self.redis_client).__name__}") + + # Test Redis connection + try: + self.redis_client.ping() + logger.info("Redis connection successful") + except Exception as e: + logger.warning(f"Redis ping failed: {e}") + # Don't raise error here as it might be a mock client in tests + + def _get_key_prefix(self) -> str: + """Generate Redis key prefix for this user and memory cube + + Returns: + Redis key prefix string + """ + return f"redis_api:{self.user_id}:{self.mem_cube_id}" + + def _get_data_key(self) -> str: + """Generate Redis key for storing serialized data + + Returns: + Redis data key string + """ + return f"{self._get_key_prefix()}:data" + + def _init_redis_client(self): + """Initialize Redis client from config or environment""" + try: + import redis + except ImportError: + logger.error("Redis package not installed. 
Install with: pip install redis") + raise + + # Try to get Redis client from environment first + if not self.redis_client: + self.redis_client = APIRedisDBManager.load_redis_engine_from_env() + + # If still no client, try from config + if not self.redis_client and self.redis_config: + redis_kwargs = { + "host": self.redis_config.get("host"), + "port": self.redis_config.get("port"), + "db": self.redis_config.get("db"), + "decode_responses": True, + } + + if self.redis_config.get("password"): + redis_kwargs["password"] = self.redis_config["password"] + + self.redis_client = redis.Redis(**redis_kwargs) + + # Final fallback to localhost + if not self.redis_client: + logger.warning("No Redis configuration found, using localhost defaults") + self.redis_client = redis.Redis( + host="localhost", port=6379, db=0, decode_responses=True + ) + + # Test connection + if not self.redis_client.ping(): + raise ConnectionError("Redis ping failed") + + logger.info("Redis client initialized successfully") + + def acquire_lock(self, block: bool = True, **kwargs) -> bool: + """Acquire a distributed lock using Redis with atomic operations + + Args: + block: Whether to block until lock is acquired + **kwargs: Additional filter criteria (ignored for Redis) + + Returns: + True if lock was acquired, False otherwise + """ + + now = get_utc_now() + + # Use Redis SET with NX (only if not exists) and EX (expiry) for atomic lock acquisition + lock_value = f"{self._get_key_prefix()}:{now.timestamp()}" + + while True: + result = self.redis_client.get(self.lock_key) + if result: + # Wait a bit before retrying + logger.info( + f"Waiting for Redis lock to be released for {self.user_id}/{self.mem_cube_id}" + ) + if not block: + logger.warning( + f"Redis lock is held for {self.user_id}/{self.mem_cube_id}, cannot acquire" + ) + return False + else: + time.sleep(0.1) + continue + else: + # Try to acquire lock atomically + result = self.redis_client.set( + self.lock_key, + lock_value, + ex=self.lock_timeout, # Set expiry in seconds + ) + logger.info(f"Redis lock acquired for {self._get_key_prefix()}") + return True + + def release_locks(self, **kwargs): + # Delete the lock key to release the lock + result = self.redis_client.delete(self.lock_key) + + # Redis DELETE returns the number of keys deleted (0 or 1) + if result > 0: + logger.info(f"Redis lock released for {self._get_key_prefix()}") + else: + logger.info(f"No Redis lock found to release for {self._get_key_prefix()}") + + def merge_items( + self, + redis_data: str, + obj_instance: APISearchHistoryManager, + size_limit: int, + ): + """Merge Redis data with current object instance + + Args: + redis_data: JSON string from Redis containing serialized APISearchHistoryManager + obj_instance: Current APISearchHistoryManager instance + size_limit: Maximum number of completed entries to keep + + Returns: + APISearchHistoryManager: Merged and synchronized manager instance + """ + + # Parse Redis data + redis_manager = APISearchHistoryManager.from_json(redis_data) + logger.debug( + f"Loaded Redis manager with {len(redis_manager.completed_entries)} completed and {len(redis_manager.running_item_ids)} running task IDs" + ) + + # Create a new merged manager with the original window size from obj_instance + # Use size_limit only for limiting entries, not as window_size + original_window_size = obj_instance.window_size + merged_manager = APISearchHistoryManager(window_size=original_window_size) + + # Merge completed entries - combine both sources and deduplicate by task_id + # Ensure all 
entries are APIMemoryHistoryEntryItem instances
+        from memos.mem_scheduler.schemas.api_schemas import APIMemoryHistoryEntryItem
+
+        all_completed = {}
+
+        # Add Redis completed entries
+        for entry in redis_manager.completed_entries:
+            if isinstance(entry, dict):
+                # Convert dict to APIMemoryHistoryEntryItem instance
+                try:
+                    entry_obj = APIMemoryHistoryEntryItem(**entry)
+                    task_id = entry_obj.item_id
+                    all_completed[task_id] = entry_obj
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to convert dict entry to APIMemoryHistoryEntryItem: {e}"
+                    )
+                    continue
+            else:
+                task_id = entry.item_id
+                all_completed[task_id] = entry
+
+        # Add current instance completed entries (these take priority if duplicated)
+        for entry in obj_instance.completed_entries:
+            if isinstance(entry, dict):
+                # Convert dict to APIMemoryHistoryEntryItem instance
+                try:
+                    entry_obj = APIMemoryHistoryEntryItem(**entry)
+                    task_id = entry_obj.item_id
+                    all_completed[task_id] = entry_obj
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to convert dict entry to APIMemoryHistoryEntryItem: {e}"
+                    )
+                    continue
+            else:
+                task_id = entry.item_id
+                all_completed[task_id] = entry
+
+        # Sort by created_time and apply size limit
+        completed_list = list(all_completed.values())
+
+        def get_created_time(entry):
+            """Helper function to safely extract created_time for sorting"""
+            from datetime import datetime
+
+            # All entries should now be APIMemoryHistoryEntryItem instances
+            return getattr(entry, "created_time", datetime.min)
+
+        completed_list.sort(key=get_created_time, reverse=True)
+        merged_manager.completed_entries = completed_list[:size_limit]
+
+        # Merge running task IDs - combine both sources and deduplicate
+        all_running_item_ids = set()
+
+        # Add Redis running task IDs
+        all_running_item_ids.update(redis_manager.running_item_ids)
+
+        # Add current instance running task IDs
+        all_running_item_ids.update(obj_instance.running_item_ids)
+
+        merged_manager.running_item_ids = list(all_running_item_ids)
+
+        logger.info(
+            f"Merged manager: {len(merged_manager.completed_entries)} completed, {len(merged_manager.running_item_ids)} running task IDs"
+        )
+        return merged_manager
+
+    def sync_with_redis(self, size_limit: int | None = None) -> None:
+        """Synchronize data between Redis and the business object
+
+        Args:
+            size_limit: Optional maximum number of items to keep after synchronization
+        """
+
+        # Use window_size from the object if size_limit is not provided
+        if size_limit is None:
+            size_limit = self.window_size
+
+        # Acquire lock before operations
+        lock_status = self.acquire_lock(block=True)
+        if not lock_status:
+            logger.error("Failed to acquire Redis lock for synchronization")
+            return
+
+        # Load existing data from Redis
+        data_key = self._get_data_key()
+        redis_data = self.redis_client.get(data_key)
+
+        if redis_data:
+            # Merge Redis data with current object
+            merged_obj = self.merge_items(
+                redis_data=redis_data, obj_instance=self.obj, size_limit=size_limit
+            )
+
+            # Update the current object with merged data
+            self.obj = merged_obj
+            logger.info(
+                f"Successfully synchronized with Redis data for {self.user_id}/{self.mem_cube_id}"
+            )
+        else:
+            logger.info(
+                f"No existing Redis data found for {self.user_id}/{self.mem_cube_id}, using current object"
+            )
+
+        # Save the synchronized object back to Redis
+        self.save_to_db(self.obj)
+
+        self.release_locks()
+
+    def save_to_db(self, obj_instance: Any) -> None:
+        """Save the current state of the business object to Redis
+
+        Args:
+            obj_instance: The object instance to save (must have to_json method)
+        """
+
+        data_key = self._get_data_key()
+
+        self.redis_client.set(data_key, obj_instance.to_json())
+
+        logger.info(f"Updated existing Redis record for {data_key}")
+
+    def load_from_db(self) -> Any | None:
+        data_key = self._get_data_key()
+
+        # Load from Redis
+        serialized_data = self.redis_client.get(data_key)
+
+        if not serialized_data:
+            logger.info(f"No Redis record found for {data_key}")
+            return None
+
+        # Deserialize the business object using the actual object type
+        if hasattr(self, "obj_type") and self.obj_type is not None:
+            db_instance = self.obj_type.from_json(serialized_data)
+        else:
+            # Default to APISearchHistoryManager for this class
+            db_instance = APISearchHistoryManager.from_json(serialized_data)
+
+        logger.info(f"Successfully loaded object from Redis for {data_key} ")
+
+        return db_instance
+
+    @classmethod
+    def from_env(
+        cls,
+        user_id: str,
+        mem_cube_id: str,
+        obj: Any | None = None,
+        lock_timeout: int = 10,
+        env_file_path: str | None = None,
+    ) -> "APIRedisDBManager":
+        """Create RedisDBManager from environment variables
+
+        Args:
+            user_id: User identifier
+            mem_cube_id: Memory cube identifier
+            obj: Optional MemoryMonitorManager instance
+            lock_timeout: Lock timeout in seconds
+            env_file_path: Optional path to .env file
+
+        Returns:
+            RedisDBManager instance
+        """
+
+        redis_client = APIRedisDBManager.load_redis_engine_from_env(env_file_path)
+        return cls(
+            user_id=user_id,
+            mem_cube_id=mem_cube_id,
+            obj=obj,
+            lock_timeout=lock_timeout,
+            redis_client=redis_client,
+        )
+
+    def close(self):
+        """Close the Redis connection and clean up resources"""
+        try:
+            if hasattr(self.redis_client, "close"):
+                self.redis_client.close()
+                logger.info(
+                    f"Redis connection closed for user_id: {self.user_id}, mem_cube_id: {self.mem_cube_id}"
+                )
+        except Exception as e:
+            logger.warning(f"Error closing Redis connection: {e}")
+
+    @staticmethod
+    def load_redis_engine_from_env(env_file_path: str | None = None) -> Any:
+        """Load Redis connection from environment variables
+
+        Args:
+            env_file_path: Path to .env file (optional, defaults to loading from current environment)
+
+        Returns:
+            Redis connection instance
+
+        Raises:
+            DatabaseError: If required environment variables are missing or connection fails
+        """
+        try:
+            import redis
+        except ImportError as e:
+            error_msg = "Redis package not installed. Install with: pip install redis"
+            logger.error(error_msg)
+            raise DatabaseError(error_msg) from e
+
+        # Load environment variables from file if provided
+        if env_file_path:
+            if os.path.exists(env_file_path):
+                from dotenv import load_dotenv
+
+                load_dotenv(env_file_path)
+                logger.info(f"Loaded environment variables from {env_file_path}")
+            else:
+                logger.warning(
+                    f"Environment file not found: {env_file_path}, using current environment variables"
+                )
+        else:
+            logger.info("Using current environment variables (no env_file_path provided)")
+
+        # Get Redis configuration from environment variables
+        redis_host = os.getenv("REDIS_HOST") or os.getenv("MEMSCHEDULER_REDIS_HOST")
+        redis_port_str = os.getenv("REDIS_PORT") or os.getenv("MEMSCHEDULER_REDIS_PORT")
+        redis_db_str = os.getenv("REDIS_DB") or os.getenv("MEMSCHEDULER_REDIS_DB")
+        redis_password = os.getenv("REDIS_PASSWORD") or os.getenv("MEMSCHEDULER_REDIS_PASSWORD")
+
+        # Check required environment variables
+        if not redis_host:
+            error_msg = (
+                "Missing required Redis environment variable: REDIS_HOST or MEMSCHEDULER_REDIS_HOST"
+            )
+            logger.error(error_msg)
+            return None
+
+        # Parse port with validation
+        try:
+            redis_port = int(redis_port_str) if redis_port_str else 6379
+        except ValueError:
+            error_msg = f"Invalid REDIS_PORT value: {redis_port_str}. Must be a valid integer."
+            logger.error(error_msg)
+            return None
+
+        # Parse database with validation
+        try:
+            redis_db = int(redis_db_str) if redis_db_str else 0
+        except ValueError:
+            error_msg = f"Invalid REDIS_DB value: {redis_db_str}. Must be a valid integer."
+            logger.error(error_msg)
+            return None
+
+        # Optional timeout settings
+        socket_timeout = os.getenv(
+            "REDIS_SOCKET_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_TIMEOUT", None)
+        )
+        socket_connect_timeout = os.getenv(
+            "REDIS_SOCKET_CONNECT_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_CONNECT_TIMEOUT", None)
+        )
+
+        try:
+            # Build Redis connection parameters
+            redis_kwargs = {
+                "host": redis_host,
+                "port": redis_port,
+                "db": redis_db,
+                "decode_responses": True,
+            }
+
+            if redis_password:
+                redis_kwargs["password"] = redis_password
+
+            if socket_timeout:
+                try:
+                    redis_kwargs["socket_timeout"] = float(socket_timeout)
+                except ValueError:
+                    logger.warning(
+                        f"Invalid REDIS_SOCKET_TIMEOUT value: {socket_timeout}, ignoring"
+                    )
+
+            if socket_connect_timeout:
+                try:
+                    redis_kwargs["socket_connect_timeout"] = float(socket_connect_timeout)
+                except ValueError:
+                    logger.warning(
+                        f"Invalid REDIS_SOCKET_CONNECT_TIMEOUT value: {socket_connect_timeout}, ignoring"
+                    )
+
+            # Create Redis connection
+            redis_client = redis.Redis(**redis_kwargs)
+
+            # Test connection
+            if not redis_client.ping():
+                raise ConnectionError("Redis ping failed")
+
+            logger.info(
+                f"Successfully created Redis connection: {redis_host}:{redis_port}/{redis_db}"
+            )
+            return redis_client
+
+        except Exception as e:
+            error_msg = f"Failed to create Redis connection from environment variables: {e}"
+            logger.error(error_msg)
+            raise DatabaseError(error_msg) from e
diff --git a/src/memos/mem_scheduler/orm_modules/base_model.py b/src/memos/mem_scheduler/orm_modules/base_model.py
index cf3fc904c..9783cea82 100644
--- a/src/memos/mem_scheduler/orm_modules/base_model.py
+++ b/src/memos/mem_scheduler/orm_modules/base_model.py
@@ -727,120 +727,3 @@ def load_mysql_engine_from_env(env_file_path: str | None = None) -> Engine | Non
         error_msg = f"Failed to create MySQL engine from environment variables: {e}"
         logger.error(error_msg)
         raise DatabaseError(error_msg) from e
-
-    @staticmethod
-    def load_redis_engine_from_env(env_file_path: str | None = None) -> Any:
-        """Load Redis connection from environment variables
-
-        Args:
-            env_file_path: Path to .env file (optional, defaults to loading from current environment)
-
-        Returns:
-            Redis connection instance
-
-        Raises:
-            DatabaseError: If required environment variables are missing or connection fails
-        """
-        try:
-            import redis
-        except ImportError as e:
-            error_msg = "Redis package not installed. Install with: pip install redis"
-            logger.error(error_msg)
-            raise DatabaseError(error_msg) from e
-
-        # Load environment variables from file if provided
-        if env_file_path:
-            if os.path.exists(env_file_path):
-                from dotenv import load_dotenv
-
-                load_dotenv(env_file_path)
-                logger.info(f"Loaded environment variables from {env_file_path}")
-            else:
-                logger.warning(
-                    f"Environment file not found: {env_file_path}, using current environment variables"
-                )
-        else:
-            logger.info("Using current environment variables (no env_file_path provided)")
-
-        # Get Redis configuration from environment variables
-        redis_host = os.getenv("REDIS_HOST") or os.getenv("MEMSCHEDULER_REDIS_HOST")
-        redis_port_str = os.getenv("REDIS_PORT") or os.getenv("MEMSCHEDULER_REDIS_PORT")
-        redis_db_str = os.getenv("REDIS_DB") or os.getenv("MEMSCHEDULER_REDIS_DB")
-        redis_password = os.getenv("REDIS_PASSWORD") or os.getenv("MEMSCHEDULER_REDIS_PASSWORD")
-
-        # Check required environment variables
-        if not redis_host:
-            error_msg = (
-                "Missing required Redis environment variable: REDIS_HOST or MEMSCHEDULER_REDIS_HOST"
-            )
-            logger.error(error_msg)
-            return None
-
-        # Parse port with validation
-        try:
-            redis_port = int(redis_port_str) if redis_port_str else 6379
-        except ValueError:
-            error_msg = f"Invalid REDIS_PORT value: {redis_port_str}. Must be a valid integer."
-            logger.error(error_msg)
-            return None
-
-        # Parse database with validation
-        try:
-            redis_db = int(redis_db_str) if redis_db_str else 0
-        except ValueError:
-            error_msg = f"Invalid REDIS_DB value: {redis_db_str}. Must be a valid integer."
-            logger.error(error_msg)
-            return None
-
-        # Optional timeout settings
-        socket_timeout = os.getenv(
-            "REDIS_SOCKET_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_TIMEOUT", None)
-        )
-        socket_connect_timeout = os.getenv(
-            "REDIS_SOCKET_CONNECT_TIMEOUT", os.getenv("MEMSCHEDULER_REDIS_CONNECT_TIMEOUT", None)
-        )
-
-        try:
-            # Build Redis connection parameters
-            redis_kwargs = {
-                "host": redis_host,
-                "port": redis_port,
-                "db": redis_db,
-                "decode_responses": True,
-            }
-
-            if redis_password:
-                redis_kwargs["password"] = redis_password
-
-            if socket_timeout:
-                try:
-                    redis_kwargs["socket_timeout"] = float(socket_timeout)
-                except ValueError:
-                    logger.warning(
-                        f"Invalid REDIS_SOCKET_TIMEOUT value: {socket_timeout}, ignoring"
-                    )
-
-            if socket_connect_timeout:
-                try:
-                    redis_kwargs["socket_connect_timeout"] = float(socket_connect_timeout)
-                except ValueError:
-                    logger.warning(
-                        f"Invalid REDIS_SOCKET_CONNECT_TIMEOUT value: {socket_connect_timeout}, ignoring"
-                    )
-
-            # Create Redis connection
-            redis_client = redis.Redis(**redis_kwargs)
-
-            # Test connection
-            if not redis_client.ping():
-                raise ConnectionError("Redis ping failed")
-
-            logger.info(
-                f"Successfully created Redis connection: {redis_host}:{redis_port}/{redis_db}"
-            )
-            return redis_client
-
-        except Exception as e:
-            error_msg = f"Failed to create Redis connection from environment variables: {e}"
-            logger.error(error_msg)
-            raise DatabaseError(error_msg) from e
diff --git a/src/memos/mem_scheduler/schemas/api_schemas.py b/src/memos/mem_scheduler/schemas/api_schemas.py
new file mode 100644
index 000000000..23eb5a848
--- /dev/null
+++ b/src/memos/mem_scheduler/schemas/api_schemas.py
@@ -0,0 +1,232 @@
+from datetime import datetime
+from enum import Enum
+from typing import Any
+from uuid import uuid4
+
+from pydantic import BaseModel, ConfigDict, Field, field_serializer
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.misc import DictConversionMixin
+from memos.mem_scheduler.utils.db_utils import get_utc_now
+from memos.memories.textual.item import TextualMemoryItem
+
+
+logger = get_logger(__name__)
+
+
+class TaskRunningStatus(str, Enum):
+    """Enumeration for task running status values."""
+
+    RUNNING = "running"
+    COMPLETED = "completed"
+
+
+class APIMemoryHistoryEntryItem(BaseModel, DictConversionMixin):
+    """Data class for search entry items stored in Redis."""
+
+    item_id: str = Field(
+        description="Unique identifier for the task", default_factory=lambda: str(uuid4())
+    )
+    query: str = Field(..., description="Search query string")
+    formatted_memories: Any = Field(..., description="Formatted search results")
+    memories: list[TextualMemoryItem] = Field(
+        default_factory=list, description="List of TextualMemoryItem objects"
+    )
+    task_status: str = Field(
+        default="running", description="Task status: running, completed, failed"
+    )
+    conversation_id: str | None = Field(
+        default=None, description="Optional conversation identifier"
+    )
+    created_time: datetime = Field(description="Entry creation time", default_factory=get_utc_now)
+    timestamp: datetime | None = Field(default=None, description="Timestamp for the entry")
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        validate_assignment=True,
+    )
+
+    @field_serializer("created_time")
+    def serialize_created_time(self, value: datetime) -> str:
+        """Serialize datetime to ISO format string."""
+        return value.isoformat()
+
+    def get(self, key: str, default: Any | None = None) -> Any:
+        """
+        Get attribute value by key name, similar to dict.get().
+
+        Args:
+            key: The attribute name to retrieve
+            default: Default value to return if attribute doesn't exist
+
+        Returns:
+            The attribute value or default if not found
+        """
+        return getattr(self, key, default)
+
+
+class APISearchHistoryManager(BaseModel, DictConversionMixin):
+    """
+    Data structure for managing search history with separate completed and running entries.
+    Supports window_size to limit the number of completed entries.
+    """
+
+    window_size: int = Field(default=5, description="Maximum number of completed entries to keep")
+    completed_entries: list[APIMemoryHistoryEntryItem] = Field(
+        default_factory=list, description="List of completed search entries"
+    )
+    running_item_ids: list[str] = Field(
+        default_factory=list, description="List of running task ids"
+    )
+
+    model_config = ConfigDict(
+        arbitrary_types_allowed=True,
+        validate_assignment=True,
+    )
+
+    def complete_entry(self, task_id: str) -> bool:
+        """
+        Remove task_id from running list when completed.
+        Note: The actual entry data should be managed separately.
+
+        Args:
+            task_id: The task ID to complete
+
+        Returns:
+            True if task_id was found and removed, False otherwise
+        """
+        if task_id in self.running_item_ids:
+            self.running_item_ids.remove(task_id)
+            logger.debug(f"Completed task_id: {task_id}")
+            return True
+
+        logger.warning(f"Task ID {task_id} not found in running task ids")
+        return False
+
+    def get_running_item_ids(self) -> list[str]:
+        """Get all running task IDs"""
+        return self.running_item_ids.copy()
+
+    def get_completed_entries(self) -> list[dict[str, Any]]:
+        """Get all completed entries"""
+        return self.completed_entries.copy()
+
+    def get_history_memory_entries(self, turns: int | None = None) -> list[dict[str, Any]]:
+        """
+        Get the most recent n completed search entries, sorted by created_time.
+
+        Args:
+            turns: Number of entries to return. If None, returns all completed entries.
+
+        Returns:
+            List of completed search entries, sorted by created_time (newest first)
+        """
+        if not self.completed_entries:
+            return []
+
+        # Sort by created_time (newest first)
+        sorted_entries = sorted(self.completed_entries, key=lambda x: x.created_time, reverse=True)
+
+        if turns is None:
+            return sorted_entries
+
+        return sorted_entries[:turns]
+
+    def get_history_memories(self, turns: int | None = None) -> list[TextualMemoryItem]:
+        """
+        Get the most recent n completed search entries, sorted by created_time.
+
+        Args:
+            turns: Number of entries to return. If None, returns all completed entries.
+
+        Returns:
+            List of TextualMemoryItem objects from completed entries, sorted by created_time (newest first)
+        """
+        sorted_entries = self.get_history_memory_entries(turns=turns)
+
+        memories = []
+        for one in sorted_entries:
+            memories.extend(one.memories)
+        return memories
+
+    def find_entry_by_item_id(self, item_id: str) -> tuple[dict[str, Any] | None, str]:
+        """
+        Find an entry by item_id in completed list only.
+        Running entries are now just task IDs, so we can only search completed entries.
+
+        Args:
+            item_id: The item ID to search for
+
+        Returns:
+            Tuple of (entry_dict, location) where location is 'completed' or 'not_found'
+        """
+        # Check completed entries
+        for entry in self.completed_entries:
+            try:
+                if hasattr(entry, "item_id") and entry.item_id == item_id:
+                    return entry.to_dict(), "completed"
+                elif isinstance(entry, dict) and entry.get("item_id") == item_id:
+                    return entry, "completed"
+            except AttributeError as e:
+                logger.warning(f"Entry missing item_id attribute: {e}, entry type: {type(entry)}")
+                continue
+
+        return None, "not_found"
+
+    def update_entry_by_item_id(
+        self,
+        item_id: str,
+        query: str,
+        formatted_memories: Any,
+        task_status: TaskRunningStatus,
+        conversation_id: str | None = None,
+        memories: list[TextualMemoryItem] | None = None,
+    ) -> bool:
+        """
+        Update an existing entry by item_id. Since running entries are now just IDs,
+        this method can only update completed entries.
+
+        Args:
+            item_id: The item ID to update
+            query: New query string
+            formatted_memories: New formatted memories
+            task_status: New task status
+            conversation_id: New conversation ID
+            memories: List of TextualMemoryItem objects
+
+        Returns:
+            True if entry was found and updated, False otherwise
+        """
+        # Find the entry in completed list
+        for entry in self.completed_entries:
+            if entry.item_id == item_id:
+                # Update the entry content
+                entry.query = query
+                entry.formatted_memories = formatted_memories
+                entry.task_status = task_status
+                if conversation_id is not None:
+                    entry.conversation_id = conversation_id
+                if memories is not None:
+                    entry.memories = memories
+
+                logger.debug(f"Updated entry with item_id: {item_id}, new status: {task_status}")
+                return True
+
+        logger.warning(f"Entry with item_id: {item_id} not found in completed entries")
+        return False
+
+    def get_total_count(self) -> dict[str, int]:
+        """Get count of entries by status"""
+        return {
+            "completed": len(self.completed_entries),
+            "running": len(self.running_item_ids),
+            "total": len(self.completed_entries) + len(self.running_item_ids),
+        }
+
+    def __len__(self) -> int:
+        """Return total number of entries (completed + running)"""
+        return len(self.completed_entries) + len(self.running_item_ids)
+
+
+# Alias for easier usage
+SearchHistoryManager = APISearchHistoryManager
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index 2bc7a3b98..a2c6434fe 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -39,6 +39,7 @@ class SearchMode(str, Enum):
 DEFAULT_TOP_K = 10
 DEFAULT_CONTEXT_WINDOW_SIZE = 5
 DEFAULT_USE_REDIS_QUEUE = False
+DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30

 # startup mode configuration
 STARTUP_BY_THREAD = "thread"
diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py
index 541d2486d..bd3155a96 100644
--- a/src/memos/mem_scheduler/schemas/message_schemas.py
+++ b/src/memos/mem_scheduler/schemas/message_schemas.py
@@ -6,7 +6,7 @@
 from typing_extensions import TypedDict

 from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
+from memos.mem_cube.base import BaseMemCube
 from memos.mem_scheduler.general_modules.misc import DictConversionMixin
 from memos.mem_scheduler.utils.db_utils import get_utc_now

@@ -35,10 +35,9 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
     item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4()))
     user_id: str = Field(..., description="user id")
-    session_id: str | None = Field(default=None, description="session id")
     mem_cube_id: str = Field(..., description="memcube id")
     label: str = Field(..., description="Label of the schedule message")
-    mem_cube: GeneralMemCube | str = Field(..., description="memcube for schedule")
+    mem_cube: BaseMemCube | str = Field(..., description="memcube for schedule")
     content: str = Field(..., description="Content of the schedule message")
     timestamp: datetime = Field(
         default_factory=get_utc_now, description="submit time for schedule_messages"
@@ -56,7 +55,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
         "example": {
             "item_id": "123e4567-e89b-12d3-a456-426614174000",  # Sample UUID
             "user_id": "user123",  # Example user identifier
-            "session_id": "session123",  # Example session identifier
             "mem_cube_id": "cube456",  # Sample memory cube ID
             "label": "sample_label",  # Demonstration label value
             "mem_cube": "obj of GeneralMemCube",  # Added mem_cube example
@@ -67,18 +65,17 @@
     )

     @field_serializer("mem_cube")
-    def serialize_mem_cube(self, cube: GeneralMemCube | str, _info) -> str:
-        """Custom serializer for GeneralMemCube objects to string representation"""
+    def serialize_mem_cube(self, cube: BaseMemCube | str, _info) -> str:
+        """Custom serializer for BaseMemCube objects to string representation"""
         if isinstance(cube, str):
             return cube
-        return f""
+        return f"<{type(cube).__name__}:{id(cube)}>"

     def to_dict(self) -> dict:
         """Convert model to dictionary suitable for Redis Stream"""
         return {
             "item_id": self.item_id,
             "user_id": self.user_id,
-            "session_id": self.session_id,
             "cube_id": self.mem_cube_id,
             "label": self.label,
             "cube": "Not Applicable",  # Custom cube serialization
@@ -93,8 +90,6 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem":
             item_id=data.get("item_id", str(uuid4())),
             user_id=data["user_id"],
             mem_cube_id=data["cube_id"],
-            session_id=data["session_id"],
-            cube_id=data["cube_id"],
             label=data["label"],
             mem_cube="Not Applicable",  # Custom cube deserialization
             content=data["content"],
diff --git a/src/memos/mem_scheduler/utils/api_utils.py b/src/memos/mem_scheduler/utils/api_utils.py
new file mode 100644
index 000000000..c8d096517
--- /dev/null
+++ b/src/memos/mem_scheduler/utils/api_utils.py
@@ -0,0 +1,76 @@
+import uuid
+
+from typing import Any
+
+from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
+from memos.memories.textual.tree import TextualMemoryItem
+
+
+def format_textual_memory_item(memory_data: Any) -> dict[str, Any]:
+    """Format a single memory item for API response."""
+    memory = memory_data.model_dump()
+    memory_id = memory["id"]
+    ref_id = f"[{memory_id.split('-')[0]}]"
+
+    memory["ref_id"] = ref_id
+    memory["metadata"]["embedding"] = []
+    memory["metadata"]["sources"] = []
+    memory["metadata"]["ref_id"] = ref_id
+    memory["metadata"]["id"] = memory_id
+    memory["metadata"]["memory"] = memory["memory"]
+
+    return memory
+
+
+def make_textual_item(memory_data):
+    return memory_data
+
+
+def text_to_textual_memory_item(
+    text: str,
+    user_id: str | None = None,
+    session_id: str | None = None,
+    memory_type: str = "WorkingMemory",
+    tags: list[str] | None = None,
+    key: str | None = None,
+    sources: list | None = None,
+    background: str = "",
+    confidence: float = 0.99,
+    embedding: list[float] | None = None,
+) -> TextualMemoryItem:
+    """
+    Convert text into a TextualMemoryItem object.
+
+    Args:
+        text: Memory content text
+        user_id: User ID
+        session_id: Session ID
+        memory_type: Memory type, defaults to "WorkingMemory"
+        tags: List of tags
+        key: Memory key or title
+        sources: List of sources
+        background: Background information
+        confidence: Confidence score (0-1)
+        embedding: Vector embedding
+
+    Returns:
+        TextualMemoryItem: Wrapped memory item
+    """
+    return TextualMemoryItem(
+        id=str(uuid.uuid4()),
+        memory=text,
+        metadata=TreeNodeTextualMemoryMetadata(
+            user_id=user_id,
+            session_id=session_id,
+            memory_type=memory_type,
+            status="activated",
+            tags=tags or [],
+            key=key,
+            embedding=embedding or [],
+            usage=[],
+            sources=sources or [],
+            background=background,
+            confidence=confidence,
+            type="fact",
+        ),
+    )
diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py
index 239557bc9..d86911e82 100644
--- a/src/memos/mem_scheduler/webservice_modules/redis_service.py
+++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py
@@ -273,7 +273,7 @@ def _cleanup_redis_resources(self):

         self._cleanup_local_redis()

-    async def redis_add_message_stream(self, message: dict):
+    def redis_add_message_stream(self, message: dict):
         logger.debug(f"add_message_stream: {message}")
         return self._redis_conn.xadd("user:queries:stream", message)

diff --git a/src/memos/memories/activation/item.py b/src/memos/memories/activation/item.py
index ba1619371..9267e6920 100644
--- a/src/memos/memories/activation/item.py
+++ b/src/memos/memories/activation/item.py
@@ -6,6 +6,8 @@
 from pydantic import BaseModel, ConfigDict, Field
 from transformers import DynamicCache

+from memos.mem_scheduler.utils.db_utils import get_utc_now
+

 class ActivationMemoryItem(BaseModel):
     id: str = Field(default_factory=lambda: str(uuid.uuid4()))
@@ -23,7 +25,7 @@ class KVCacheRecords(BaseModel):
         description="Single string combining all text_memories using assembly template",
     )
     timestamp: datetime = Field(
-        default_factory=datetime.utcnow, description="submit time for schedule_messages"
+        default_factory=get_utc_now, description="submit time for schedule_messages"
     )

diff --git a/tests/mem_scheduler/test_orm.py b/tests/mem_scheduler/test_orm.py
deleted file mode 100644
index a43231e4a..000000000
--- a/tests/mem_scheduler/test_orm.py
+++ /dev/null
@@ -1,447 +0,0 @@
-import os
-import tempfile
-import time
-
-from datetime import datetime, timedelta
-
-import pytest
-
-from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
-
-# Import the classes to test
-from memos.mem_scheduler.orm_modules.monitor_models import (
-    DBManagerForMemoryMonitorManager,
-    DBManagerForQueryMonitorQueue,
-)
-from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager
-from memos.mem_scheduler.schemas.monitor_schemas import (
-    MemoryMonitorItem,
-    MemoryMonitorManager,
-    QueryMonitorItem,
-    QueryMonitorQueue,
-)
-
-
-# Test data
-TEST_USER_ID = "test_user"
-TEST_MEM_CUBE_ID = "test_mem_cube"
-TEST_QUEUE_ID = "test_queue"
-
-
-class TestBaseDBManager:
-    """Base class for DBManager tests with common fixtures"""
-
-    @pytest.fixture
-    def temp_db(self):
-        """Create a temporary database for testing."""
-        temp_dir = tempfile.mkdtemp()
-        db_path = os.path.join(temp_dir, "test_scheduler_orm.db")
-        yield db_path
-        # Cleanup
-        try:
-            if os.path.exists(db_path):
-                os.remove(db_path)
-            os.rmdir(temp_dir)
-        except (OSError, PermissionError):
-            pass  # Ignore cleanup errors (e.g., file locked on Windows)
-
-    @pytest.fixture
-    def memory_manager_obj(self):
-        """Create a MemoryMonitorManager object for testing"""
-        return MemoryMonitorManager(
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            items=[
-                MemoryMonitorItem(
-                    item_id="custom-id-123",
-                    memory_text="Full test memory",
-                    tree_memory_item=None,
-                    tree_memory_item_mapping_key="full_test_key",
-                    keywords_score=0.8,
-                    sorting_score=0.9,
-                    importance_score=0.7,
-                    recording_count=3,
-                )
-            ],
-        )
-
-    @pytest.fixture
-    def query_queue_obj(self):
-        """Create a QueryMonitorQueue object for testing"""
-        queue = QueryMonitorQueue()
-        queue.put(
-            QueryMonitorItem(
-                item_id="query1",
-                user_id=TEST_USER_ID,
-                mem_cube_id=TEST_MEM_CUBE_ID,
-                query_text="How are you?",
-                timestamp=datetime.now(),
-                keywords=["how", "you"],
-            )
-        )
-        return queue
-
-    @pytest.fixture
-    def query_monitor_manager(self, temp_db, query_queue_obj):
-        """Create DBManagerForQueryMonitorQueue instance with temp DB."""
-        engine = BaseDBManager.create_engine_from_db_path(temp_db)
-        manager = DBManagerForQueryMonitorQueue(
-            engine=engine,
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            obj=query_queue_obj,
-            lock_timeout=10,
-        )
-
-        assert manager.engine is not None
-        assert manager.SessionLocal is not None
-        assert os.path.exists(temp_db)
-
-        yield manager
-        manager.close()
-
-    @pytest.fixture
-    def memory_monitor_manager(self, temp_db, memory_manager_obj):
-        """Create DBManagerForMemoryMonitorManager instance with temp DB."""
-        engine = BaseDBManager.create_engine_from_db_path(temp_db)
-        manager = DBManagerForMemoryMonitorManager(
-            engine=engine,
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            obj=memory_manager_obj,
-            lock_timeout=10,
-        )
-
-        assert manager.engine is not None
-        assert manager.SessionLocal is not None
-        assert os.path.exists(temp_db)
-
-        yield manager
-        manager.close()
-
-    def test_save_and_load_query_queue(self, query_monitor_manager, query_queue_obj):
-        """Test saving and loading QueryMonitorQueue."""
-        # Save to database
-        query_monitor_manager.save_to_db(query_queue_obj)
-
-        # Load in a new manager
-        engine = BaseDBManager.create_engine_from_db_path(query_monitor_manager.engine.url.database)
-        new_manager = DBManagerForQueryMonitorQueue(
-            engine=engine,
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            obj=None,
-            lock_timeout=10,
-        )
-        loaded_queue = new_manager.load_from_db(acquire_lock=True)
-
-        assert loaded_queue is not None
-        items = loaded_queue.get_queue_content_without_pop()
-        assert len(items) == 1
-        assert items[0].item_id == "query1"
-        assert items[0].query_text == "How are you?"
-        new_manager.close()
-
-    def test_lock_mechanism(self, query_monitor_manager, query_queue_obj):
-        """Test lock acquisition and release."""
-        # Save current state
-        query_monitor_manager.save_to_db(query_queue_obj)
-
-        # Acquire lock
-        acquired = query_monitor_manager.acquire_lock(block=True)
-        assert acquired
-
-        # Try to acquire again (should fail without blocking)
-        assert not query_monitor_manager.acquire_lock(block=False)
-
-        # Release lock
-        query_monitor_manager.release_locks(
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-        )
-
-        # Should be able to acquire again
-        assert query_monitor_manager.acquire_lock(block=False)
-
-    def test_lock_timeout(self, query_monitor_manager, query_queue_obj):
-        """Test lock timeout mechanism."""
-        # Save current state
-        query_monitor_manager.save_to_db(query_queue_obj)
-
-        query_monitor_manager.lock_timeout = 1
-
-        # Acquire lock
-        assert query_monitor_manager.acquire_lock(block=True)
-
-        # Wait for lock to expire
-        time.sleep(1.1)
-
-        # Should be able to acquire again
-        assert query_monitor_manager.acquire_lock(block=False)
-
-    def test_sync_with_orm(self, query_monitor_manager, query_queue_obj):
-        """Test synchronization between ORM and object."""
-        query_queue_obj.put(
-            QueryMonitorItem(
-                item_id="query2",
-                user_id=TEST_USER_ID,
-                mem_cube_id=TEST_MEM_CUBE_ID,
-                query_text="What's your name?",
-                timestamp=datetime.now(),
-                keywords=["name"],
-            )
-        )
-
-        # Save current state
-        query_monitor_manager.save_to_db(query_queue_obj)
-
-        # Create sync manager with empty queue
-        empty_queue = QueryMonitorQueue(maxsize=10)
-        engine = BaseDBManager.create_engine_from_db_path(query_monitor_manager.engine.url.database)
-        sync_manager = DBManagerForQueryMonitorQueue(
-            engine=engine,
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            obj=empty_queue,
-            lock_timeout=10,
-        )
-
-        # First sync - should create a new record with empty queue
-        sync_manager.sync_with_orm(size_limit=None)
-        items = sync_manager.obj.get_queue_content_without_pop()
-        assert len(items) == 0  # Empty queue since no existing data to merge
-
-        # Now save the empty queue to create a record
-        sync_manager.save_to_db(empty_queue)
-
-        # Test that sync_with_orm correctly handles version control
-        # The sync should increment version but not merge data when versions are the same
-        sync_manager.sync_with_orm(size_limit=None)
-        items = sync_manager.obj.get_queue_content_without_pop()
-        assert len(items) == 0  # Should remain empty since no merge occurred
-
-        # Verify that the version was incremented
-        assert sync_manager.last_version_control == "3"  # Should increment from 2 to 3
-
-        sync_manager.close()
-
-    def test_sync_with_size_limit(self, query_monitor_manager, query_queue_obj):
-        """Test synchronization with size limit."""
-        now = datetime.now()
-        item_size = 1
-        for i in range(2, 6):
-            item_size += 1
-            query_queue_obj.put(
-                QueryMonitorItem(
-                    item_id=f"query{i}",
-                    user_id=TEST_USER_ID,
-                    mem_cube_id=TEST_MEM_CUBE_ID,
-                    query_text=f"Question {i}",
-                    timestamp=now + timedelta(minutes=i),
-                    keywords=[f"kw{i}"],
-                )
-            )
-
-        # First sync - should create a new record (size_limit not applied for new records)
-        size_limit = 3
-        query_monitor_manager.sync_with_orm(size_limit=size_limit)
-        items = query_monitor_manager.obj.get_queue_content_without_pop()
-        assert len(items) == item_size  # All items since size_limit not applied for new records
-
-        # Save to create the record
-        query_monitor_manager.save_to_db(query_monitor_manager.obj)
-
-        # Test that sync_with_orm correctly handles version control
-        # The sync should increment version but not merge data when versions are the same
-        query_monitor_manager.sync_with_orm(size_limit=size_limit)
-        items = query_monitor_manager.obj.get_queue_content_without_pop()
-        assert len(items) == item_size  # Should remain the same since no merge occurred
-
-        # Verify that the version was incremented
-        assert query_monitor_manager.last_version_control == "2"
-
-    def test_concurrent_access(self, temp_db, query_queue_obj):
-        """Test concurrent access to the same database."""
-
-        # Manager 1
-        engine1 = BaseDBManager.create_engine_from_db_path(temp_db)
-        manager1 = DBManagerForQueryMonitorQueue(
-            engine=engine1,
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            obj=query_queue_obj,
-            lock_timeout=10,
-        )
-        manager1.save_to_db(query_queue_obj)
-
-        # Manager 2
-        engine2 = BaseDBManager.create_engine_from_db_path(temp_db)
-        manager2 = DBManagerForQueryMonitorQueue(
-            engine=engine2,
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            obj=query_queue_obj,
-            lock_timeout=10,
-        )
-
-        # Manager1 acquires lock
-        assert manager1.acquire_lock(block=True)
-
-        # Manager2 fails to acquire
-        assert not manager2.acquire_lock(block=False)
-
-        # Manager1 releases
-        manager1.release_locks(user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID)
-
-        # Manager2 can now acquire
-        assert manager2.acquire_lock(block=False)
-
-        manager1.close()
-        manager2.close()
-
-
-class TestRedisDBManager:
-    """Test class for RedisDBManager functionality"""
-
-    @pytest.fixture
-    def memory_manager_obj(self):
-        """Create a MemoryMonitorManager object for testing"""
-        return MemoryMonitorManager(
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            memories=[
-                MemoryMonitorItem(
-                    item_id="redis-test-123",
-                    memory_text="Redis test memory",
-                    tree_memory_item=None,
-                    tree_memory_item_mapping_key="redis_test_key",
-                    keywords_score=0.8,
-                    sorting_score=0.9,
-                    importance_score=0.7,
-                    recording_count=3,
-                )
-            ],
-        )
-
-    @pytest.fixture
-    def mock_redis_client(self):
-        """Create a mock Redis client for testing"""
-        try:
-            from unittest.mock import MagicMock
-
-            # Create a mock Redis client
-            mock_client = MagicMock()
-
-            # Mock Redis data storage
-            mock_data = {}
-
-            def mock_set(key, value, nx=False, ex=None, **kwargs):
-                if nx and key in mock_data:
-                    # NX means "only set if not exists"
-                    return False  # Redis returns False when NX fails
-                mock_data[key] = value
-                return True
-
-            def mock_get(key):
-                return mock_data.get(key)
-
-            def mock_hset(key, mapping=None, **kwargs):
-                if key not in mock_data:
-                    mock_data[key] = {}
-                if mapping:
-                    mock_data[key].update(mapping)
-                if kwargs:
-                    mock_data[key].update(kwargs)
-                return len(mapping) if mapping else len(kwargs)
-
-            def mock_hgetall(key):
-                return mock_data.get(key, {})
-
-            def mock_delete(*keys):
-                deleted = 0
-                for key in keys:
-                    if key in mock_data:
-                        del mock_data[key]
-                        deleted += 1
-                return deleted
-
-            def mock_keys(pattern):
-                import fnmatch
-
-                return [key for key in mock_data if fnmatch.fnmatch(key, pattern)]
-
-            def mock_ping():
-                return True
-
-            def mock_close():
-                pass
-
-            # Configure mock methods
-            mock_client.set = mock_set
-            mock_client.get = mock_get
-            mock_client.hset = mock_hset
-            mock_client.hgetall = mock_hgetall
-            mock_client.delete = mock_delete
-            mock_client.keys = mock_keys
-            mock_client.ping = mock_ping
-            mock_client.close = mock_close
-
-            return mock_client
-
-        except ImportError:
-            pytest.skip("Redis package not available for testing")
-
-    @pytest.fixture
-    def redis_manager(self, mock_redis_client, memory_manager_obj):
-        """Create RedisDBManager instance with mock Redis client"""
-        manager = RedisDBManager(
-            user_id=TEST_USER_ID,
-            mem_cube_id=TEST_MEM_CUBE_ID,
-            obj=memory_manager_obj,
-            lock_timeout=10,
-            redis_client=mock_redis_client,
-        )
-        yield manager
-        manager.close()
-
-    def test_redis_manager_initialization(self, mock_redis_client):
-        """Test RedisDBManager initialization"""
-        manager = RedisDBManager(
-            user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID, redis_client=mock_redis_client
-        )
-
-        assert manager.user_id == TEST_USER_ID
-        assert manager.mem_cube_id == TEST_MEM_CUBE_ID
-        assert manager.redis_client is mock_redis_client
-        assert manager.orm_class.__name__ == "RedisLockableORM"
-        assert manager.obj_class == MemoryMonitorManager
-
-        manager.close()
-
-    def test_redis_lockable_orm_save_load(self, mock_redis_client):
-        """Test RedisLockableORM save and load operations"""
-        from memos.mem_scheduler.orm_modules.redis_model import RedisLockableORM
-
-        orm = RedisLockableORM(
-            redis_client=mock_redis_client, user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID
-        )
-
-        # Test save
-        orm.serialized_data = '{"test": "data"}'
-        orm.version_control = "1"
-        orm.lock_acquired = True
-        orm.lock_expiry = datetime.now()
-
-        orm.save()
-
-        # Test load
-        new_orm = RedisLockableORM(
-            redis_client=mock_redis_client, user_id=TEST_USER_ID, mem_cube_id=TEST_MEM_CUBE_ID
-        )
-
-        exists = new_orm.load()
-        assert exists
-        assert new_orm.serialized_data == '{"test": "data"}'
-        assert new_orm.version_control == "1"
-        # Note: lock_acquired is False after load by design - locks are managed separately
-        assert not new_orm.lock_acquired
diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py
index 369b4a6f1..03a8e4318 100644
--- a/tests/mem_scheduler/test_scheduler.py
+++ b/tests/mem_scheduler/test_scheduler.py
@@ -204,7 +204,6 @@ def test_scheduler_startup_mode_thread(self):

     def test_redis_message_queue(self):
         """Test Redis message queue functionality for sending and receiving messages."""
-        import asyncio
         import time
         from unittest.mock import MagicMock, patch

@@ -244,7 +243,7 @@ def redis_handler(messages: list[ScheduleMessageItem]) -> None:
         )

         # Submit message to Redis queue
-        asyncio.run(self.scheduler.submit_messages(redis_message))
+        self.scheduler.submit_messages(redis_message)

         # Verify Redis xadd was called
         mock_redis.xadd.assert_called_once()

@@ -529,55 +528,3 @@ def test_get_running_tasks_multiple_tasks(self):

         # Verify dispatcher method was called
         mock_get_running_tasks.assert_called_once_with(filter_func=None)
-
-    def test_message_handler_receives_submitted_message(self):
-        """Test that handlers receive messages after scheduler startup and message submission."""
-        # Create a mock handler that tracks received messages
-        received_messages = []
-
-        def mock_handler(messages: list[ScheduleMessageItem]) -> None:
-            """Mock handler that records received messages."""
-            received_messages.extend(messages)
-
-        # Register the mock handler
-        test_label = "test_handler"
-        handlers = {test_label: mock_handler}
-        self.scheduler.register_handlers(handlers)
-
-        # Verify handler is registered
-        self.assertIn(test_label, self.scheduler.handlers)
-        self.assertEqual(self.scheduler.handlers[test_label], mock_handler)
-
-        # Start the scheduler
-        self.scheduler.start()
-
-        # Create and submit a test message
-        test_message = ScheduleMessageItem(
-            label=test_label,
-            content="Test message content",
-            user_id="test_user",
-            mem_cube_id="test_mem_cube",
-            mem_cube="test_mem_cube_obj",  # Required field - can be string or GeneralMemCube
-            timestamp=datetime.now(),
-        )
-
-        import asyncio
-
-        asyncio.run(self.scheduler.submit_messages(test_message))
-
-        # Wait for message processing to complete
-        import time
-
-        time.sleep(2.0)  # Allow sufficient time for message processing
-
-        # Verify the handler received the message
-        self.assertEqual(
-            len(received_messages), 1, f"Expected 1 message, got {len(received_messages)}"
-        )
-        self.assertEqual(received_messages[0].label, test_label)
-        self.assertEqual(received_messages[0].content, "Test message content")
-        self.assertEqual(received_messages[0].user_id, "test_user")
-        self.assertEqual(received_messages[0].mem_cube_id, "test_mem_cube")
-
-        # Stop the scheduler
-        self.scheduler.stop()
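
Reviewer note, not part of the patch: the snippet below is a minimal sketch of how the new search-history schemas added in src/memos/mem_scheduler/schemas/api_schemas.py and the helper in src/memos/mem_scheduler/utils/api_utils.py might be exercised together. Import paths, field names, and method signatures are taken from the diff above; the sample strings, the "u1" user id, and the direct appends to running_item_ids / completed_entries are illustrative assumptions, since the patch does not show the calling code that feeds entries into the manager.

# Sketch only - assumes the modules land at the paths shown in this diff.
from memos.mem_scheduler.schemas.api_schemas import (
    APIMemoryHistoryEntryItem,
    APISearchHistoryManager,
    TaskRunningStatus,
)
from memos.mem_scheduler.utils.api_utils import text_to_textual_memory_item

# Keep at most five completed searches in the history window.
manager = APISearchHistoryManager(window_size=5)

# A search task starts: only its item_id is tracked while it runs.
entry = APIMemoryHistoryEntryItem(
    query="What did the user ask about Redis locks?",
    formatted_memories=[],  # filled in once the search completes
)
manager.running_item_ids.append(entry.item_id)

# The task finishes: attach results, move the entry into the completed
# window, and drop its id from the running list via complete_entry().
entry.task_status = TaskRunningStatus.COMPLETED
entry.memories = [text_to_textual_memory_item("User prefers concise answers", user_id="u1")]
entry.formatted_memories = [{"memory": "User prefers concise answers"}]
manager.completed_entries.append(entry)
manager.complete_entry(entry.item_id)

print(manager.get_total_count())                    # {'completed': 1, 'running': 0, 'total': 1}
print(len(manager.get_history_memories(turns=5)))   # 1

The same completed/running split is what merge_items() and sync_with_redis() in the Redis manager above reconcile across processes, deduplicating completed entries by item_id and unioning the running id lists.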