diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index eb15b48ed..37a5f1a8b 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -71,6 +71,32 @@ def __init__(
         self.manual_close_internet = manual_close_internet
         self.tokenizer = tokenizer
         self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")
+        self._retrieve_executor = ContextThreadPoolExecutor(max_workers=5, thread_name_prefix="retrieve")
+        self.retrieve_timeout_seconds = 20.0
+
+    def close(self) -> None:
+        """Release background resources (executors).
+
+        Safe to call multiple times.
+        """
+        for ex_name in ("_usage_executor", "_retrieve_executor"):
+            ex = getattr(self, ex_name, None)
+            if ex is None:
+                continue
+            try:
+                ex.shutdown(wait=False, cancel_futures=True)
+            except TypeError:
+                # Python<3.9 compatibility: cancel_futures not supported
+                ex.shutdown(wait=False)
+            except Exception:
+                logger.debug("[Searcher] failed to shutdown %s", ex_name, exc_info=True)
+
+    def __del__(self):
+        try:
+            self.close()
+        except Exception:
+            pass
+

     @timed
     def retrieve(
@@ -346,7 +372,11 @@ def _retrieve_paths(
         include_preference_memory: bool = False,
         pref_mem_top_k: int = 6,
     ):
-        """Run A/B/C/D/E/F retrieval paths in parallel"""
+        """Run A/B/C/D/E/F retrieval paths in parallel.
+
+        IMPORTANT: do NOT create a new ThreadPoolExecutor per request.
+        Reuse an instance-level pool to avoid thread leaks / exhaustion.
+        """
         tasks = []
         id_filter = {
             "user_id": info.get("user_id", None),
@@ -354,10 +384,54 @@
             "session_id": info.get("session_id", None),
         }
         id_filter = {k: v for k, v in id_filter.items() if v is not None}

-        with ContextThreadPoolExecutor(max_workers=5) as executor:
+        executor = self._retrieve_executor
+
+        tasks.append(
+            executor.submit(
+                self._retrieve_from_working_memory,
+                query,
+                parsed_goal,
+                query_embedding,
+                top_k,
+                memory_type,
+                search_filter,
+                search_priority,
+                user_name,
+                id_filter,
+            )
+        )
+        tasks.append(
+            executor.submit(
+                self._retrieve_from_long_term_and_user,
+                query,
+                parsed_goal,
+                query_embedding,
+                top_k,
+                memory_type,
+                search_filter,
+                search_priority,
+                user_name,
+                id_filter,
+                mode=mode,
+            )
+        )
+        tasks.append(
+            executor.submit(
+                self._retrieve_from_internet,
+                query,
+                parsed_goal,
+                query_embedding,
+                top_k,
+                info,
+                mode,
+                memory_type,
+                user_name,
+            )
+        )
+        if self.use_fulltext:
             tasks.append(
                 executor.submit(
-                    self._retrieve_from_working_memory,
+                    self._retrieve_from_keyword,
                     query,
                     parsed_goal,
                     query_embedding,
                     top_k,
                     memory_type,
                     search_filter,
                     search_priority,
                     user_name,
                     id_filter,
                 )
             )
+        if search_tool_memory:
             tasks.append(
                 executor.submit(
-                    self._retrieve_from_long_term_and_user,
+                    self._retrieve_from_tool_memory,
                     query,
                     parsed_goal,
                     query_embedding,
-                    top_k,
+                    tool_mem_top_k,
                     memory_type,
                     search_filter,
                     search_priority,
                     user_name,
                     id_filter,
                     mode=mode,
                 )
             )
@@ -384,85 +459,47 @@
+        if include_skill_memory:
             tasks.append(
                 executor.submit(
-                    self._retrieve_from_internet,
+                    self._retrieve_from_skill_memory,
                     query,
                     parsed_goal,
                     query_embedding,
-                    top_k,
-                    info,
-                    mode,
+                    skill_mem_top_k,
                     memory_type,
+                    search_filter,
+                    search_priority,
                     user_name,
+                    id_filter,
+                    mode=mode,
                 )
             )
-            if self.use_fulltext:
-                tasks.append(
-                    executor.submit(
-                        self._retrieve_from_keyword,
-                        query,
-                        parsed_goal,
-                        query_embedding,
-                        top_k,
-                        memory_type,
-                        search_filter,
-                        search_priority,
-                        user_name,
-                        id_filter,
-                    )
-                )
-            if search_tool_memory:
-                tasks.append(
-                    executor.submit(
-                        self._retrieve_from_tool_memory,
-                        query,
-                        parsed_goal,
-                        query_embedding,
-                        tool_mem_top_k,
-                        memory_type,
-                        search_filter,
-                        search_priority,
-                        user_name,
-                        id_filter,
-                        mode=mode,
-                    )
-                )
-            if include_skill_memory:
-                tasks.append(
-                    executor.submit(
-                        self._retrieve_from_skill_memory,
-                        query,
-                        parsed_goal,
-                        query_embedding,
-                        skill_mem_top_k,
-                        memory_type,
-                        search_filter,
-                        search_priority,
-                        user_name,
-                        id_filter,
-                        mode=mode,
-                    )
-                )
-            if include_preference_memory:
-                tasks.append(
-                    executor.submit(
-                        self._retrieve_from_preference_memory,
-                        query,
-                        parsed_goal,
-                        query_embedding,
-                        pref_mem_top_k,
-                        memory_type,
-                        search_filter,
-                        search_priority,
-                        user_name,
-                        id_filter,
-                        mode=mode,
-                    )
+        if include_preference_memory:
+            tasks.append(
+                executor.submit(
+                    self._retrieve_from_preference_memory,
+                    query,
+                    parsed_goal,
+                    query_embedding,
+                    pref_mem_top_k,
+                    memory_type,
+                    search_filter,
+                    search_priority,
+                    user_name,
+                    id_filter,
+                    mode=mode,
                 )
-            results = []
-            for t in tasks:
-                results.extend(t.result())
+            )
+
+        results = []
+        timeout_s = getattr(self, "retrieve_timeout_seconds", 20.0)
+        for t in tasks:
+            try:
+                results.extend(t.result(timeout=timeout_s))
+            except Exception as e:
+                t.cancel()
+                logger.warning("[SEARCH] retrieve path failed/timeout: %s", e, exc_info=True)

         logger.info(f"[SEARCH] Total raw results: {len(results)}")
         return results
diff --git a/tests/test_searcher_executor_reuse.py b/tests/test_searcher_executor_reuse.py
new file mode 100644
index 000000000..f54212000
--- /dev/null
+++ b/tests/test_searcher_executor_reuse.py
@@ -0,0 +1,147 @@
+import threading
+
+
+def test_searcher_reuses_retrieve_executor_and_does_not_spawn_unbounded_threads():
+    """Regression test for #1273.
+
+    Searcher._retrieve_paths used to create a new ContextThreadPoolExecutor per call.
+    That pattern can leak threads under load / long-running requests.
+
+    We don't import the full memos package here (deps may be heavy); instead we
+    re-create a minimal slice of the Searcher implementation inline, with a fake
+    ContextThreadPoolExecutor.
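+
+    The fake executor runs submitted callables synchronously on the calling
+    thread, so no real worker threads are created and the result is
+    deterministic.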
+ """ + + created_executors = [] + + class FakeFuture: + def __init__(self, result): + self._result = result + + def result(self, timeout=None): + return self._result + + def cancel(self): + return False + + class FakeExecutor: + def __init__(self, max_workers=1, thread_name_prefix=None): + self.max_workers = max_workers + self.thread_name_prefix = thread_name_prefix + created_executors.append(self) + + def submit(self, fn, *args, **kwargs): + return FakeFuture(fn(*args, **kwargs)) + + def shutdown(self, wait=False, cancel_futures=False): + return None + + # minimal Searcher with only the bits we need + class Searcher: + def __init__(self): + self.use_fulltext = False + self._usage_executor = FakeExecutor(max_workers=4, thread_name_prefix="usage") + self._retrieve_executor = FakeExecutor(max_workers=5, thread_name_prefix="retrieve") + self.retrieve_timeout_seconds = 0.01 + + def _retrieve_from_working_memory(self, *args, **kwargs): + return ["A"] + + def _retrieve_from_long_term_and_user(self, *args, **kwargs): + return ["B"] + + def _retrieve_from_internet(self, *args, **kwargs): + return ["C"] + + def _retrieve_paths( + self, + query, + parsed_goal, + query_embedding, + info, + top_k, + mode, + memory_type, + search_filter=None, + search_priority=None, + user_name=None, + search_tool_memory=False, + tool_mem_top_k=6, + include_skill_memory=False, + skill_mem_top_k=3, + include_preference_memory=False, + pref_mem_top_k=6, + ): + tasks = [] + id_filter = {"user_id": info.get("user_id"), "session_id": info.get("session_id")} + id_filter = {k: v for k, v in id_filter.items() if v is not None} + + executor = self._retrieve_executor + tasks.append( + executor.submit( + self._retrieve_from_working_memory, + query, + parsed_goal, + query_embedding, + top_k, + memory_type, + search_filter, + search_priority, + user_name, + id_filter, + ) + ) + tasks.append( + executor.submit( + self._retrieve_from_long_term_and_user, + query, + parsed_goal, + query_embedding, + top_k, + memory_type, + search_filter, + search_priority, + user_name, + id_filter, + mode=mode, + ) + ) + tasks.append( + executor.submit( + self._retrieve_from_internet, + query, + parsed_goal, + query_embedding, + top_k, + info, + mode, + memory_type, + user_name, + ) + ) + + results = [] + timeout_s = getattr(self, "retrieve_timeout_seconds", 20.0) + for t in tasks: + results.extend(t.result(timeout=timeout_s)) + return results + + s = Searcher() + + # call multiple times; should not create new executors beyond the two created in __init__ + for _ in range(200): + out = s._retrieve_paths( + query="q", + parsed_goal=None, + query_embedding=[[]], + info={"user_id": "u", "session_id": "s"}, + top_k=3, + mode="fast", + memory_type="All", + ) + assert out == ["A", "B", "C"] + + assert len(created_executors) == 2 + + # heuristic: thread count should not grow wildly in this unit test + # (FakeExecutor doesn't spawn threads at all) + assert threading.active_count() < 200