diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index 395d3fbc7..03622922d 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -866,6 +866,7 @@ def create_user_config(user_name: str, user_id: str) -> tuple[MOSConfig, General
                     "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
                 },
                 "search_strategy": {
+                    "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
                     "bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
                     "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
                 },
@@ -937,6 +938,7 @@ def get_default_cube_config() -> GeneralMemCubeConfig | None:
                     "UserMemory": os.getenv("NEBULAR_USER_MEMORY", 1e6),
                 },
                 "search_strategy": {
+                    "fast_graph": bool(os.getenv("FAST_GRAPH", "false") == "true"),
                     "bm25": bool(os.getenv("BM25_CALL", "false") == "true"),
                     "cot": bool(os.getenv("VEC_COT_CALL", "false") == "true"),
                 },
diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py
index 6974dbe8f..992b7bfab 100644
--- a/src/memos/memories/textual/simple_tree.py
+++ b/src/memos/memories/textual/simple_tree.py
@@ -70,12 +70,6 @@ def __init__(
         )
         logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}")

-        self.vec_cot = (
-            self.search_strategy["cot"]
-            if self.search_strategy and "cot" in self.search_strategy
-            else False
-        )
-
         time_start_rr = time.time()
         self.reranker = reranker
         logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
@@ -189,7 +183,7 @@ def search(
                 bm25_retriever=self.bm25_retriever,
                 internet_retriever=None,
                 moscube=moscube,
-                vec_cot=self.vec_cot,
+                search_strategy=self.search_strategy,
             )
         else:
             searcher = Searcher(
@@ -200,7 +194,7 @@ def search(
                 bm25_retriever=self.bm25_retriever,
                 internet_retriever=self.internet_retriever,
                 moscube=moscube,
-                vec_cot=self.vec_cot,
+                search_strategy=self.search_strategy,
             )
         return searcher.search(
             query, top_k, info, mode, memory_type, search_filter, user_name=user_name
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index a58f993bb..19bd3ba5b 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -51,11 +51,6 @@ def __init__(self, config: TreeTextMemoryConfig):
         self.bm25_retriever = (
             EnhancedBM25() if self.search_strategy and self.search_strategy["bm25"] else None
         )
-        self.vec_cot = (
-            self.search_strategy["cot"]
-            if self.search_strategy and "cot" in self.search_strategy
-            else False
-        )

         if config.reranker is None:
             default_cfg = RerankerConfigFactory.model_validate(
@@ -143,6 +138,7 @@ def get_searcher(
                 self.reranker,
                 internet_retriever=None,
                 moscube=moscube,
+                search_strategy=self.search_strategy,
             )
         else:
             searcher = Searcher(
@@ -152,6 +148,7 @@ def get_searcher(
                 self.reranker,
                 internet_retriever=self.internet_retriever,
                 moscube=moscube,
+                search_strategy=self.search_strategy,
             )
         return searcher

diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index b7383aa13..8cf2f47f3 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -40,6 +40,7 @@ def retrieve(
         search_filter: dict | None = None,
         user_name: str | None = None,
         id_filter: dict | None = None,
+        use_fast_graph: bool = False,
     ) -> list[TextualMemoryItem]:
         """
         Perform hybrid memory retrieval:
@@ -69,7 +70,13 @@ def retrieve(

         with ContextThreadPoolExecutor(max_workers=3) as executor:
             # Structured graph-based retrieval
-            future_graph = executor.submit(self._graph_recall, parsed_goal, memory_scope, user_name)
+            future_graph = executor.submit(
+                self._graph_recall,
+                parsed_goal,
+                memory_scope,
+                user_name,
+                use_fast_graph=use_fast_graph,
+            )
             # Vector similarity search
             future_vector = executor.submit(
                 self._vector_recall,
@@ -155,7 +162,7 @@ def retrieve_from_cube(
         return list(combined.values())

     def _graph_recall(
-        self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None
+        self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs
     ) -> list[TextualMemoryItem]:
         """
         Perform structured node-based retrieval from Neo4j.
@@ -163,6 +170,7 @@ def _graph_recall(
         - tags must overlap with at least 2 input tags
         - scope filters by memory_type if provided
         """
+        use_fast_graph = kwargs.get("use_fast_graph", False)

         def process_node(node):
             meta = node.get("metadata", {})
@@ -184,47 +192,96 @@ def process_node(node):
                 return TextualMemoryItem.from_dict(node)
             return None

-        candidate_ids = set()
-
-        # 1) key-based OR branch
-        if parsed_goal.keys:
-            key_filters = [
-                {"field": "key", "op": "in", "value": parsed_goal.keys},
-                {"field": "memory_type", "op": "=", "value": memory_scope},
-            ]
-            key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name)
-            candidate_ids.update(key_ids)
-
-        # 2) tag-based OR branch
-        if parsed_goal.tags:
-            tag_filters = [
-                {"field": "tags", "op": "contains", "value": parsed_goal.tags},
-                {"field": "memory_type", "op": "=", "value": memory_scope},
-            ]
-            tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name)
-            candidate_ids.update(tag_ids)
-
-        # No matches → return empty
-        if not candidate_ids:
-            return []
+        if not use_fast_graph:
+            candidate_ids = set()

-        # Load nodes and post-filter
-        node_dicts = self.graph_store.get_nodes(
-            list(candidate_ids), include_embedding=False, user_name=user_name
-        )
+            # 1) key-based OR branch
+            if parsed_goal.keys:
+                key_filters = [
+                    {"field": "key", "op": "in", "value": parsed_goal.keys},
+                    {"field": "memory_type", "op": "=", "value": memory_scope},
+                ]
+                key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name)
+                candidate_ids.update(key_ids)
+
+            # 2) tag-based OR branch
+            if parsed_goal.tags:
+                tag_filters = [
+                    {"field": "tags", "op": "contains", "value": parsed_goal.tags},
+                    {"field": "memory_type", "op": "=", "value": memory_scope},
+                ]
+                tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name)
+                candidate_ids.update(tag_ids)

-        final_nodes = []
-        with ContextThreadPoolExecutor(max_workers=3) as executor:
-            futures = {executor.submit(process_node, node): i for i, node in enumerate(node_dicts)}
-            temp_results = [None] * len(node_dicts)
+            # No matches → return empty
+            if not candidate_ids:
+                return []
+
+            # Load nodes and post-filter
+            node_dicts = self.graph_store.get_nodes(list(candidate_ids), include_embedding=False, user_name=user_name)
+
+            final_nodes = []
+            for node in node_dicts:
+                meta = node.get("metadata", {})
+                node_key = meta.get("key")
+                node_tags = meta.get("tags", []) or []
+
+                keep = False
+                # key equals to node_key
+                if parsed_goal.keys and node_key in parsed_goal.keys:
+                    keep = True
+                # overlap tags more than 2
+                elif parsed_goal.tags:
+                    overlap = len(set(node_tags) & set(parsed_goal.tags))
+                    if overlap >= 2:
+                        keep = True
+                if keep:
+                    final_nodes.append(TextualMemoryItem.from_dict(node))
+            return final_nodes
+        else:
+            candidate_ids = set()
+
+            # 1) key-based OR branch
+            if parsed_goal.keys:
+                key_filters = [
+                    {"field": "key", "op": "in", "value": parsed_goal.keys},
{"field": "memory_type", "op": "=", "value": memory_scope}, + ] + key_ids = self.graph_store.get_by_metadata(key_filters, user_name=user_name) + candidate_ids.update(key_ids) + + # 2) tag-based OR branch + if parsed_goal.tags: + tag_filters = [ + {"field": "tags", "op": "contains", "value": parsed_goal.tags}, + {"field": "memory_type", "op": "=", "value": memory_scope}, + ] + tag_ids = self.graph_store.get_by_metadata(tag_filters, user_name=user_name) + candidate_ids.update(tag_ids) + + # No matches → return empty + if not candidate_ids: + return [] + + # Load nodes and post-filter + node_dicts = self.graph_store.get_nodes( + list(candidate_ids), include_embedding=False, user_name=user_name + ) + + final_nodes = [] + with ContextThreadPoolExecutor(max_workers=3) as executor: + futures = { + executor.submit(process_node, node): i for i, node in enumerate(node_dicts) + } + temp_results = [None] * len(node_dicts) - for future in concurrent.futures.as_completed(futures): - original_index = futures[future] - result = future.result() - temp_results[original_index] = result + for future in concurrent.futures.as_completed(futures): + original_index = futures[future] + result = future.result() + temp_results[original_index] = result - final_nodes = [result for result in temp_results if result is not None] - return final_nodes + final_nodes = [result for result in temp_results if result is not None] + return final_nodes def _vector_recall( self, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 563695c68..0974d67f2 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -45,7 +45,7 @@ def __init__( bm25_retriever: EnhancedBM25 | None = None, internet_retriever: None = None, moscube: bool = False, - vec_cot: bool = False, + search_strategy: dict | None = None, ): self.graph_store = graph_store self.embedder = embedder @@ -59,7 +59,12 @@ def __init__( # Create internet retriever from config if provided self.internet_retriever = internet_retriever self.moscube = moscube - self.vec_cot = vec_cot + self.vec_cot = ( + search_strategy.get("vec_cot", "false") == "true" if search_strategy else False + ) + self.use_fast_graph = ( + search_strategy.get("fast_graph", "false") == "true" if search_strategy else False + ) self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @@ -226,6 +231,7 @@ def _parse_task( context="\n".join(context), conversation=info.get("chat_history", []), mode=mode, + use_fast_graph=self.use_fast_graph, ) query = parsed_goal.rephrased_query or query @@ -340,6 +346,7 @@ def _retrieve_from_working_memory( search_filter=search_filter, user_name=user_name, id_filter=id_filter, + use_fast_graph=self.use_fast_graph, ) return self.reranker.rerank( query=query, @@ -390,6 +397,7 @@ def _retrieve_from_long_term_and_user( search_filter=search_filter, user_name=user_name, id_filter=id_filter, + use_fast_graph=self.use_fast_graph, ) ) if memory_type in ["All", "UserMemory"]: @@ -404,6 +412,7 @@ def _retrieve_from_long_term_and_user( search_filter=search_filter, user_name=user_name, id_filter=id_filter, + use_fast_graph=self.use_fast_graph, ) ) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 6a1138c90..5d706559c 100644 --- 
--- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
@@ -29,6 +29,7 @@ def parse(
         context: str = "",
         conversation: list[dict] | None = None,
         mode: str = "fast",
+        **kwargs,
     ) -> ParsedTaskGoal:
         """
         Parse user input into structured semantic layers.
@@ -38,7 +39,7 @@ def parse(
         - mode == 'fine': use LLM to parse structured topic/keys/tags
         """
         if mode == "fast":
-            return self._parse_fast(task_description)
+            return self._parse_fast(task_description, **kwargs)
         elif mode == "fine":
             if not self.llm:
                 raise ValueError("LLM not provided for slow mode.")
@@ -46,19 +47,30 @@ def parse(
         else:
             raise ValueError(f"Unknown mode: {mode}")

-    def _parse_fast(self, task_description: str, limit_num: int = 5) -> ParsedTaskGoal:
+    def _parse_fast(self, task_description: str, **kwargs) -> ParsedTaskGoal:
         """
         Fast mode: simple jieba word split.
         """
-        desc_tokenized = self.tokenizer.tokenize_mixed(task_description)
-        return ParsedTaskGoal(
-            memories=[task_description],
-            keys=desc_tokenized,
-            tags=desc_tokenized,
-            goal_type="default",
-            rephrased_query=task_description,
-            internet_search=False,
-        )
+        use_fast_graph = kwargs.get("use_fast_graph", False)
+        if use_fast_graph:
+            desc_tokenized = self.tokenizer.tokenize_mixed(task_description)
+            return ParsedTaskGoal(
+                memories=[task_description],
+                keys=desc_tokenized,
+                tags=desc_tokenized,
+                goal_type="default",
+                rephrased_query=task_description,
+                internet_search=False,
+            )
+        else:
+            return ParsedTaskGoal(
+                memories=[task_description],
+                keys=[task_description],
+                tags=[],
+                goal_type="default",
+                rephrased_query=task_description,
+                internet_search=False,
+            )

     def _parse_fine(
         self, query: str, context: str = "", conversation: list[dict] | None = None
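
Usage note (editor's sketch, not part of the diff): the snippet below illustrates how the new FAST_GRAPH switch is expected to reach the Searcher, assuming the search_strategy dict is built exactly as in create_user_config above; the surrounding cube and Searcher wiring is omitted and the variable names here are illustrative only.

    import os

    # Mirror of the "search_strategy" block assembled in src/memos/api/config.py.
    search_strategy = {
        "fast_graph": os.getenv("FAST_GRAPH", "false") == "true",
        "bm25": os.getenv("BM25_CALL", "false") == "true",
        "cot": os.getenv("VEC_COT_CALL", "false") == "true",
    }

    # Passing Searcher(..., search_strategy=search_strategy) derives two flags:
    #   use_fast_graph=True  -> TaskGoalParser._parse_fast tokenizes the query with
    #                           jieba and _graph_recall keeps the threaded post-filter
    #   use_fast_graph=False -> the raw query is used as the only key and
    #                           _graph_recall filters candidate nodes sequentially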