diff --git a/src/memos/api/exceptions.py b/src/memos/api/exceptions.py index 2fd22ad52..10a14b4d1 100644 --- a/src/memos/api/exceptions.py +++ b/src/memos/api/exceptions.py @@ -1,5 +1,6 @@ import logging +from fastapi.exceptions import HTTPException, RequestValidationError from fastapi.requests import Request from fastapi.responses import JSONResponse @@ -10,9 +11,24 @@ class APIExceptionHandler: """Centralized exception handling for MemOS APIs.""" + @staticmethod + async def validation_error_handler(request: Request, exc: RequestValidationError): + """Handle request validation errors.""" + logger.error(f"Validation error: {exc.errors()}") + return JSONResponse( + status_code=422, + content={ + "code": 422, + "message": "Parameter validation error", + "detail": exc.errors(), + "data": None, + }, + ) + @staticmethod async def value_error_handler(request: Request, exc: ValueError): """Handle ValueError exceptions globally.""" + logger.error(f"ValueError: {exc}") return JSONResponse( status_code=400, content={"code": 400, "message": str(exc), "data": None}, @@ -21,8 +37,17 @@ async def value_error_handler(request: Request, exc: ValueError): @staticmethod async def global_exception_handler(request: Request, exc: Exception): """Handle all unhandled exceptions globally.""" - logger.exception("Unhandled error:") + logger.error(f"Exception: {exc}") return JSONResponse( status_code=500, content={"code": 500, "message": str(exc), "data": None}, ) + + @staticmethod + async def http_error_handler(request: Request, exc: HTTPException): + """Handle HTTP exceptions globally.""" + logger.error(f"HTTP error {exc.status_code}: {exc.detail}") + return JSONResponse( + status_code=exc.status_code, + content={"code": exc.status_code, "message": str(exc.detail), "data": None}, + ) diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index cb41428d4..2922ab3eb 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -2,6 +2,8 @@ Request context middleware for automatic trace_id injection. """ +import time + from collections.abc import Callable from starlette.middleware.base import BaseHTTPMiddleware @@ -38,8 +40,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: # Extract or generate trace_id trace_id = extract_trace_id_from_headers(request) or generate_trace_id() + env = request.headers.get("x-env") + user_type = request.headers.get("x-user-type") + user_name = request.headers.get("x-user-name") + start_time = time.time() + # Create and set request context - context = RequestContext(trace_id=trace_id, api_path=request.url.path) + context = RequestContext( + trace_id=trace_id, + api_path=request.url.path, + env=env, + user_type=user_type, + user_name=user_name, + ) set_request_context(context) # Log request start with parameters @@ -49,15 +62,25 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: if request.query_params: params_log["query_params"] = dict(request.query_params) - logger.info(f"Request started: {request.method} {request.url.path}, {params_log}") + logger.info(f"Request started, params: {params_log}, headers: {request.headers}") # Process the request - response = await call_next(request) - - # Log request completion with output - logger.info(f"Request completed: {request.url.path}, status: {response.status_code}") - - # Add trace_id to response headers for debugging - response.headers["x-trace-id"] = trace_id + try: + response = await call_next(request) + end_time = time.time() + if response.status_code == 200: + logger.info( + f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + else: + logger.error( + f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + except Exception as e: + end_time = time.time() + logger.error( + f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + raise e return response diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index f50d3ad75..3ba12c1ce 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,7 +1,6 @@ import os import traceback -from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Any from fastapi import APIRouter, HTTPException @@ -22,6 +21,7 @@ from memos.configs.mem_scheduler import SchedulerConfigFactory from memos.configs.reranker import RerankerConfigFactory from memos.configs.vec_db import VectorDBConfigFactory +from memos.context.context import ContextThreadPoolExecutor from memos.embedders.factory import EmbedderFactory from memos.graph_dbs.factory import GraphStoreFactory from memos.llms.factory import LLMFactory @@ -370,7 +370,7 @@ def _search_pref(): ) return [_format_memory_item(data) for data in results] - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_search_text) pref_future = executor.submit(_search_pref) text_formatted_memories = text_future.result() @@ -532,7 +532,7 @@ def _process_pref_mem() -> list[dict[str, str]]: for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) ] - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_process_text_mem) pref_future = executor.submit(_process_pref_mem) text_response_data = text_future.result() diff --git a/src/memos/api/server_api.py b/src/memos/api/server_api.py index 78e05ef85..24c67de48 100644 --- a/src/memos/api/server_api.py +++ b/src/memos/api/server_api.py @@ -1,6 +1,7 @@ import logging -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException +from fastapi.exceptions import RequestValidationError from memos.api.exceptions import APIExceptionHandler from memos.api.middleware.request_context import RequestContextMiddleware @@ -21,8 +22,13 @@ # Include routers app.include_router(server_router) -# Exception handlers +# Request validation failed +app.exception_handler(RequestValidationError)(APIExceptionHandler.validation_error_handler) +# Invalid business code parameters app.exception_handler(ValueError)(APIExceptionHandler.value_error_handler) +# Business layer manual exception +app.exception_handler(HTTPException)(APIExceptionHandler.http_error_handler) +# Fallback for unknown errors app.exception_handler(Exception)(APIExceptionHandler.global_exception_handler) diff --git a/src/memos/context/context.py b/src/memos/context/context.py index 4f54348fb..d6a0f3bf1 100644 --- a/src/memos/context/context.py +++ b/src/memos/context/context.py @@ -29,9 +29,19 @@ class RequestContext: This provides a Flask g-like object for FastAPI applications. """ - def __init__(self, trace_id: str | None = None, api_path: str | None = None): + def __init__( + self, + trace_id: str | None = None, + api_path: str | None = None, + env: str | None = None, + user_type: str | None = None, + user_name: str | None = None, + ): self.trace_id = trace_id or "trace-id" self.api_path = api_path + self.env = env + self.user_type = user_type + self.user_name = user_name self._data: dict[str, Any] = {} def set(self, key: str, value: Any) -> None: @@ -43,7 +53,13 @@ def get(self, key: str, default: Any | None = None) -> Any: return self._data.get(key, default) def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_") or name in ("trace_id", "api_path"): + if name.startswith("_") or name in ( + "trace_id", + "api_path", + "env", + "user_type", + "user_name", + ): super().__setattr__(name, value) else: if not hasattr(self, "_data"): @@ -58,7 +74,14 @@ def __getattr__(self, name: str) -> Any: def to_dict(self) -> dict[str, Any]: """Convert context to dictionary.""" - return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()} + return { + "trace_id": self.trace_id, + "api_path": self.api_path, + "env": self.env, + "user_type": self.user_type, + "user_name": self.user_name, + "data": self._data.copy(), + } def set_request_context(context: RequestContext) -> None: @@ -93,6 +116,36 @@ def get_current_api_path() -> str | None: return None +def get_current_env() -> str | None: + """ + Get the current request's env. + """ + context = _request_context.get() + if context: + return context.get("env") + return "prod" + + +def get_current_user_type() -> str | None: + """ + Get the current request's user type. + """ + context = _request_context.get() + if context: + return context.get("user_type") + return "opensource" + + +def get_current_user_name() -> str | None: + """ + Get the current request's user name. + """ + context = _request_context.get() + if context: + return context.get("user_name") + return "memos" + + def get_current_context() -> RequestContext | None: """ Get the current request context. @@ -103,7 +156,11 @@ def get_current_context() -> RequestContext | None: context_dict = _request_context.get() if context_dict: ctx = RequestContext( - trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path") + trace_id=context_dict.get("trace_id"), + api_path=context_dict.get("api_path"), + env=context_dict.get("env"), + user_type=context_dict.get("user_type"), + user_name=context_dict.get("user_name"), ) ctx._data = context_dict.get("data", {}).copy() return ctx @@ -141,6 +198,9 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs): self.main_trace_id = get_current_trace_id() self.main_api_path = get_current_api_path() + self.main_env = get_current_env() + self.main_user_type = get_current_user_type() + self.main_user_name = get_current_user_name() self.main_context = get_current_context() def run(self): @@ -148,7 +208,11 @@ def run(self): if self.main_context: # Copy the context data child_context = RequestContext( - trace_id=self.main_trace_id, api_path=self.main_context.api_path + trace_id=self.main_trace_id, + api_path=self.main_api_path, + env=self.main_env, + user_type=self.main_user_type, + user_name=self.main_user_name, ) child_context._data = self.main_context._data.copy() @@ -171,13 +235,22 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any: """ main_trace_id = get_current_trace_id() main_api_path = get_current_api_path() + main_env = get_current_env() + main_user_type = get_current_user_type() + main_user_name = get_current_user_name() main_context = get_current_context() @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: if main_context: # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path) + child_context = RequestContext( + trace_id=main_trace_id, + api_path=main_api_path, + env=main_env, + user_type=main_user_type, + user_name=main_user_name, + ) child_context._data = main_context._data.copy() set_request_context(child_context) @@ -198,13 +271,22 @@ def map( """ main_trace_id = get_current_trace_id() main_api_path = get_current_api_path() + main_env = get_current_env() + main_user_type = get_current_user_type() + main_user_name = get_current_user_name() main_context = get_current_context() @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: if main_context: # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path) + child_context = RequestContext( + trace_id=main_trace_id, + api_path=main_api_path, + env=main_env, + user_type=main_user_type, + user_name=main_user_name, + ) child_context._data = main_context._data.copy() set_request_context(child_context) diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index 72116cf05..fc51cf073 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -3,6 +3,11 @@ from memos.configs.embedder import UniversalAPIEmbedderConfig from memos.embedders.base import BaseEmbedder +from memos.log import get_logger +from memos.utils import timed + + +logger = get_logger(__name__) class UniversalAPIEmbedder(BaseEmbedder): @@ -19,14 +24,18 @@ def __init__(self, config: UniversalAPIEmbedderConfig): api_key=config.api_key, ) else: - raise ValueError(f"Unsupported provider: {self.provider}") + raise ValueError(f"Embeddings unsupported provider: {self.provider}") + @timed(log=True, log_prefix="EmbedderAPI") def embed(self, texts: list[str]) -> list[list[float]]: if self.provider == "openai" or self.provider == "azure": - response = self.client.embeddings.create( - model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"), - input=texts, - ) - return [r.embedding for r in response.data] + try: + response = self.client.embeddings.create( + model=getattr(self.config, "model_name_or_path", "text-embedding-3-large"), + input=texts, + ) + return [r.embedding for r in response.data] + except Exception as e: + raise Exception(f"Embeddings request ended with error: {e}") from e else: - raise ValueError(f"Unsupported provider: {self.provider}") + raise ValueError(f"Embeddings unsupported provider: {self.provider}") diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index b9bc2c8e5..88aef6d33 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1,18 +1,18 @@ import json -import time import random + from datetime import datetime from typing import Any, Literal import numpy as np - from memos.configs.graph_db import PolarDBGraphDBConfig from memos.dependency import require_python_package from memos.graph_dbs.base import BaseGraphDB from memos.log import get_logger from memos.utils import timed + logger = get_logger(__name__) # Graph database configuration @@ -200,31 +200,31 @@ def _create_graph(self): # Add embedding column if it doesn't exist (using JSONB for compatibility) try: cursor.execute(f""" - ALTER TABLE "{self.db_name}_graph"."Memory" + ALTER TABLE "{self.db_name}_graph"."Memory" ADD COLUMN IF NOT EXISTS embedding JSONB; """) - logger.info(f"Embedding column added to Memory table.") + logger.info("Embedding column added to Memory table.") except Exception as e: logger.warning(f"Failed to add embedding column: {e}") # Create indexes cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties + CREATE INDEX IF NOT EXISTS idx_memory_properties ON "{self.db_name}_graph"."Memory" USING GIN (properties); """) # Create vector index for embedding field try: cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding + CREATE INDEX IF NOT EXISTS idx_memory_embedding ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100); """) - logger.info(f"Vector index created for Memory table.") + logger.info("Vector index created for Memory table.") except Exception as e: logger.warning(f"Vector index creation failed (might not be supported): {e}") - logger.info(f"Indexes created for Memory table.") + logger.info("Indexes created for Memory table.") except Exception as e: logger.error(f"Failed to create graph schema: {e}") @@ -246,20 +246,20 @@ def create_index( # Create indexes on the underlying PostgreSQL tables # Apache AGE stores data in regular PostgreSQL tables cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_properties + CREATE INDEX IF NOT EXISTS idx_memory_properties ON "{self.db_name}_graph"."Memory" USING GIN (properties); """) # Try to create vector index, but don't fail if it doesn't work try: cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_memory_embedding + CREATE INDEX IF NOT EXISTS idx_memory_embedding ON "{self.db_name}_graph"."Memory" USING ivfflat (embedding vector_cosine_ops); """) except Exception as ve: logger.warning(f"Vector index creation failed (might not be supported): {ve}") - logger.debug(f"Indexes created successfully.") + logger.debug("Indexes created successfully.") except Exception as e: logger.warning(f"Failed to create indexes: {e}") @@ -267,8 +267,8 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in """Get count of memory nodes by type.""" user_name = user_name if user_name else self._get_config_value("user_name") query = f""" - SELECT COUNT(*) - FROM "{self.db_name}_graph"."Memory" + SELECT COUNT(*) + FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype """ query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" @@ -290,8 +290,8 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int: """Check if a node with given scope exists.""" user_name = user_name if user_name else self._get_config_value("user_name") query = f""" - SELECT id - FROM "{self.db_name}_graph"."Memory" + SELECT id + FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype """ query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype" @@ -327,15 +327,13 @@ def remove_oldest_memory( # Use actual OFFSET logic, consistent with nebular.py # First find IDs to delete, then delete them select_query = f""" - SELECT id FROM "{self.db_name}_graph"."Memory" + SELECT id FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"memory_type"'::agtype) = %s::agtype AND ag_catalog.agtype_access_operator(properties, '"user_name"'::agtype) = %s::agtype - ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC + ORDER BY ag_catalog.agtype_access_operator(properties, '"updated_at"'::agtype) DESC OFFSET %s """ select_params = [f'"{memory_type}"', f'"{user_name}"', keep_latest] - print(f"[remove_oldest_memory] Select query: {select_query}") - print(f"[remove_oldest_memory] Select params: {select_params}") try: with self.connection.cursor() as cursor: @@ -403,14 +401,14 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N # Build update query if embedding_vector is not None: query = f""" - UPDATE "{self.db_name}_graph"."Memory" + UPDATE "{self.db_name}_graph"."Memory" SET properties = %s, embedding = %s WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ params = [json.dumps(properties), json.dumps(embedding_vector), f'"{id}"'] else: query = f""" - UPDATE "{self.db_name}_graph"."Memory" + UPDATE "{self.db_name}_graph"."Memory" SET properties = %s WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ @@ -438,7 +436,7 @@ def delete_node(self, id: str, user_name: str | None = None) -> None: user_name (str, optional): User name for filtering in non-multi-db mode """ query = f""" - DELETE FROM "{self.db_name}_graph"."Memory" + DELETE FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ params = [f'"{id}"'] @@ -462,7 +460,7 @@ def create_extension(self): try: with self.connection.cursor() as cursor: # Ensure in the correct database context - cursor.execute(f"SELECT current_database();") + cursor.execute("SELECT current_database();") current_db = cursor.fetchone()[0] print(f"Current database context: {current_db}") @@ -487,7 +485,7 @@ def create_graph(self): try: with self.connection.cursor() as cursor: cursor.execute(f""" - SELECT COUNT(*) FROM ag_catalog.ag_graph + SELECT COUNT(*) FROM ag_catalog.ag_graph WHERE name = '{self.db_name}_graph'; """) graph_exists = cursor.fetchone()[0] > 0 @@ -664,11 +662,11 @@ def edge_exists( # Prepare the match pattern with direction if direction == "OUTGOING": - pattern = f"(a:Memory)-[r]->(b:Memory)" + pattern = "(a:Memory)-[r]->(b:Memory)" elif direction == "INCOMING": - pattern = f"(a:Memory)<-[r]-(b:Memory)" + pattern = "(a:Memory)<-[r]-(b:Memory)" elif direction == "ANY": - pattern = f"(a:Memory)-[r]-(b:Memory)" + pattern = "(a:Memory)-[r]-(b:Memory)" else: raise ValueError( f"Invalid direction: {direction}. Must be 'OUTGOING', 'INCOMING', or 'ANY'." @@ -720,7 +718,7 @@ def format_param_value(value: str) -> str: query = f""" SELECT {select_fields} - FROM "{self.db_name}_graph"."Memory" + FROM "{self.db_name}_graph"."Memory" WHERE ag_catalog.agtype_access_operator(properties, '"id"'::agtype) = %s::agtype """ params = [format_param_value(id)] @@ -806,7 +804,7 @@ def get_nodes( query = f""" SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" + FROM "{self.db_name}_graph"."Memory" WHERE ({where_clause}) """ @@ -893,15 +891,15 @@ def get_edges_old( # Create indexes cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_source + CREATE INDEX IF NOT EXISTS idx_edges_source ON "{self.db_name}_graph"."Edges" (source_id); """) cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_target + CREATE INDEX IF NOT EXISTS idx_edges_target ON "{self.db_name}_graph"."Edges" (target_id); """) cursor.execute(f""" - CREATE INDEX IF NOT EXISTS idx_edges_type + CREATE INDEX IF NOT EXISTS idx_edges_type ON "{self.db_name}_graph"."Edges" (edge_type); """) except Exception as e: @@ -998,7 +996,7 @@ def get_neighbors_by_tag_old( # Get all candidate nodes query = f""" SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" + FROM "{self.db_name}_graph"."Memory" WHERE {where_clause} """ @@ -1061,7 +1059,7 @@ def get_children_with_embeddings( SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (p:Memory)-[r:PARENT]->(c:Memory) - WHERE p.id = '{id}' {where_user} + WHERE p.id = '{id}' {where_user} RETURN id(c) as cid, c.id AS id, c.memory AS memory $$) as (cid agtype, id agtype, memory agtype) ) @@ -1518,7 +1516,7 @@ def get_grouped_counts1( MATCH (n:Memory) {where_clause} RETURN {group_fields_cypher}, COUNT(n) AS count1 - $$ ) as ({group_fields_cypher_polardb}, count1 agtype); + $$ ) as ({group_fields_cypher_polardb}, count1 agtype); """ print("get_grouped_counts:" + query) try: @@ -1673,8 +1671,8 @@ def clear(self, user_name: str | None = None) -> None: try: query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ - MATCH (n:Memory) - WHERE n.user_name = '{user_name}' + MATCH (n:Memory) + WHERE n.user_name = '{user_name}' DETACH DELETE n $$) AS (result agtype) """ @@ -1765,7 +1763,7 @@ def export_graph( SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (a:Memory)-[r]->(b:Memory) WHERE a.user_name = '{user_name}' AND b.user_name = '{user_name}' - RETURN a.id AS source, b.id AS target, type(r) as edge + RETURN a.id AS source, b.id AS target, type(r) as edge $$) AS (source agtype, target agtype, edge agtype) """ @@ -1840,7 +1838,7 @@ def count_nodes(self, scope: str, user_name: str | None = None) -> int: query = f""" SELECT * FROM cypher('{self.db_name}_graph', $$ MATCH (n:Memory) - WHERE n.memory_type = '{scope}' + WHERE n.memory_type = '{scope}' AND n.user_name = '{user_name}' RETURN count(n) $$) AS (count agtype) @@ -1879,8 +1877,8 @@ def get_all_memory_items( LIMIT 100 $$) AS (id1 agtype,n agtype) ) - SELECT - m.embedding, + SELECT + m.embedding, t.n FROM t, {self.db_name}_graph."Memory" m @@ -1976,8 +1974,8 @@ def get_all_memory_items_old( LIMIT 100 $$) AS (id1 agtype,n agtype) ) - SELECT - m.embedding, + SELECT + m.embedding, t.n FROM t, {self.db_name}_graph."Memory" m @@ -2144,8 +2142,8 @@ def get_structure_optimization_candidates( WITH t as ( {cypher_query} ) - SELECT - m.embedding, + SELECT + m.embedding, t.n FROM t, {self.db_name}_graph."Memory" m @@ -2358,7 +2356,7 @@ def add_node( with self.connection.cursor() as cursor: # Delete existing record first (if any) delete_query = f""" - DELETE FROM {self.db_name}_graph."Memory" + DELETE FROM {self.db_name}_graph."Memory" WHERE id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, %s::text::cstring) """ cursor.execute(delete_query, (id,)) @@ -2493,7 +2491,7 @@ def get_neighbors_by_tag( # Fetch all candidate nodes query = f""" SELECT id, properties, embedding - FROM "{self.db_name}_graph"."Memory" + FROM "{self.db_name}_graph"."Memory" WHERE {where_clause} """ @@ -2769,13 +2767,13 @@ def get_edges( user_name = user_name if user_name else self._get_config_value("user_name") if direction == "OUTGOING": - pattern = f"(a:Memory)-[r]->(b:Memory)" + pattern = "(a:Memory)-[r]->(b:Memory)" where_clause = f"a.id = '{id}'" elif direction == "INCOMING": - pattern = f"(a:Memory)<-[r]-(b:Memory)" + pattern = "(a:Memory)<-[r]-(b:Memory)" where_clause = f"a.id = '{id}'" elif direction == "ANY": - pattern = f"(a:Memory)-[r]-(b:Memory)" + pattern = "(a:Memory)-[r]-(b:Memory)" where_clause = f"a.id = '{id}' OR b.id = '{id}'" else: raise ValueError("Invalid direction. Must be 'OUTGOING', 'INCOMING', or 'ANY'.") diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 698bc3265..ca1df5c1f 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -11,6 +11,7 @@ from memos.llms.utils import remove_thinking_tags from memos.log import get_logger from memos.types import MessageList +from memos.utils import timed logger = get_logger(__name__) @@ -56,6 +57,7 @@ def clear_cache(cls): cls._instances.clear() logger.info("OpenAI LLM instance cache cleared") + @timed(log=True, log_prefix="OpenAI LLM") def generate(self, messages: MessageList) -> str: """Generate a response from OpenAI LLM.""" response = self.client.chat.completions.create( @@ -73,6 +75,7 @@ def generate(self, messages: MessageList) -> str: else: return response_content + @timed(log=True, log_prefix="OpenAI LLM") def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from OpenAI LLM with optional reasoning support.""" response = self.client.chat.completions.create( diff --git a/src/memos/log.py b/src/memos/log.py index 339d13f26..2a538fdde 100644 --- a/src/memos/log.py +++ b/src/memos/log.py @@ -14,7 +14,13 @@ from dotenv import load_dotenv from memos import settings -from memos.context.context import get_current_api_path, get_current_trace_id +from memos.context.context import ( + get_current_api_path, + get_current_env, + get_current_trace_id, + get_current_user_name, + get_current_user_type, +) # Load environment variables @@ -34,15 +40,22 @@ def _setup_logfile() -> Path: return logfile -class TraceIDFilter(logging.Filter): - """add trace_id to the log record""" +class ContextFilter(logging.Filter): + """add context to the log record""" def filter(self, record): try: trace_id = get_current_trace_id() record.trace_id = trace_id if trace_id else "trace-id" + record.env = get_current_env() + record.user_type = get_current_user_type() + record.user_name = get_current_user_name() + record.api_path = get_current_api_path() except Exception: record.trace_id = "trace-id" + record.env = "prod" + record.user_type = "normal" + record.user_name = "unknown" return True @@ -86,13 +99,24 @@ def emit(self, record): try: trace_id = get_current_trace_id() or "trace-id" api_path = get_current_api_path() + env = get_current_env() + user_type = get_current_user_type() + user_name = get_current_user_name() if api_path is not None: - self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path) + self._executor.submit( + self._send_log_sync, + record.getMessage(), + trace_id, + api_path, + env, + user_type, + user_name, + ) except Exception as e: if not self._is_shutting_down.is_set(): print(f"Error sending log: {e}") - def _send_log_sync(self, message, trace_id, api_path): + def _send_log_sync(self, message, trace_id, api_path, env, user_type, user_name): """Send log message synchronously in a separate thread""" try: logger_url = os.getenv("CUSTOM_LOGGER_URL") @@ -104,6 +128,9 @@ def _send_log_sync(self, message, trace_id, api_path): "trace_id": trace_id, "action": api_path, "current_time": round(time.time(), 3), + "env": env, + "user_type": user_type, + "user_name": user_name, } # Add auth token if exists @@ -145,26 +172,26 @@ def close(self): "disable_existing_loggers": False, "formatters": { "standard": { - "format": "%(asctime)s [%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" + "format": "%(asctime)s | %(trace_id)s | path=%(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" }, "no_datetime": { - "format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" + "format": "%(trace_id)s | path=%(api_path)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" }, "simplified": { - "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s" + "format": "%(asctime)s | %(trace_id)s | path=%(api_path)s | % %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s" }, }, "filters": { "package_tree_filter": {"()": "logging.Filter", "name": settings.LOG_FILTER_TREE_PREFIX}, - "trace_id_filter": {"()": "memos.log.TraceIDFilter"}, + "context_filter": {"()": "memos.log.ContextFilter"}, }, "handlers": { "console": { - "level": selected_log_level, + "level": "DEBUG", "class": "logging.StreamHandler", "stream": stdout, "formatter": "no_datetime", - "filters": ["package_tree_filter", "trace_id_filter"], + "filters": ["package_tree_filter", "context_filter"], }, "file": { "level": "DEBUG", @@ -173,7 +200,7 @@ def close(self): "maxBytes": 1024**2 * 10, "backupCount": 10, "formatter": "standard", - "filters": ["trace_id_filter"], + "filters": ["context_filter"], }, "custom_logger": { "level": "INFO", diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index ec8a673d7..939b0c68d 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -2,13 +2,13 @@ import os import time -from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path from threading import Lock from typing import Any, Literal from memos.configs.mem_os import MOSConfig +from memos.context.context import ContextThreadPoolExecutor from memos.llms.factory import LLMFactory from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube @@ -665,7 +665,7 @@ def search_preference_memory(cube_id, cube): return None # Execute both search functions in parallel - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(search_textual_memory, mem_cube_id, mem_cube) pref_future = executor.submit(search_preference_memory, mem_cube_id, mem_cube) @@ -824,7 +824,7 @@ def process_preference_memory(): self.mem_scheduler.submit_messages(messages=[message_item]) # Execute both memory processing functions in parallel - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(process_textual_memory) pref_future = executor.submit(process_preference_memory) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index d84ebb242..434cef3e9 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -3,6 +3,7 @@ import traceback from memos.configs.mem_scheduler import GeneralSchedulerConfig +from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.base_scheduler import BaseScheduler @@ -281,7 +282,7 @@ def process_message(message: ScheduleMessageItem): except Exception as e: logger.error(f"Error processing mem_read message: {e}", exc_info=True) - with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: futures = [executor.submit(process_message, msg) for msg in messages] for future in concurrent.futures.as_completed(futures): try: @@ -413,7 +414,7 @@ def process_message(message: ScheduleMessageItem): except Exception as e: logger.error(f"Error processing mem_read message: {e}", exc_info=True) - with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: futures = [executor.submit(process_message, msg) for msg in messages] for future in concurrent.futures.as_completed(futures): try: @@ -506,7 +507,7 @@ def process_message(message: ScheduleMessageItem): except Exception as e: logger.error(f"Error processing pref_add message: {e}", exc_info=True) - with concurrent.futures.ThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: + with ContextThreadPoolExecutor(max_workers=min(8, len(messages))) as executor: futures = [executor.submit(process_message, msg) for msg in messages] for future in concurrent.futures.as_completed(futures): try: diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index 390f048ef..eb284cd6d 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -1,9 +1,10 @@ import json from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import as_completed from typing import Any +from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem from memos.templates.prefer_complete_prompt import ( @@ -162,7 +163,7 @@ def execute_op(op): self.vector_db.delete(collection_name, [op["target_id"]]) return None - with ThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor: + with ContextThreadPoolExecutor(max_workers=min(len(rsp["trace"]), 5)) as executor: future_to_op = {executor.submit(execute_op, op): op for op in rsp["trace"]} added_ids = [] for future in as_completed(future_to_op): @@ -263,7 +264,7 @@ def add( return [] added_ids = [] - with ThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: + with ContextThreadPoolExecutor(max_workers=min(max_workers, len(memories))) as executor: future_to_memory = { executor.submit(self._process_single_memory, memory): memory for memory in memories } diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 460b31f4f..41d90d10e 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -2,10 +2,11 @@ import uuid from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import as_completed from datetime import datetime from typing import Any +from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.prefer_text_memory.spliter import Splitter @@ -150,7 +151,7 @@ def extract( return [] memories = [] - with ThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor: + with ContextThreadPoolExecutor(max_workers=min(max_workers, len(chunks))) as executor: futures = { executor.submit(self._process_single_chunk_explicit, chunk, msg_type, info): ( "explicit", diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 7f70bac3b..807a8b55e 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -from concurrent.futures import ThreadPoolExecutor from typing import Any +from memos.context.context import ContextThreadPoolExecutor from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem @@ -42,7 +42,7 @@ def retrieve( query_embedding = query_embeddings[0] # Get the first (and only) embedding # Use thread pool to parallelize the searches - with ThreadPoolExecutor(max_workers=2) as executor: + with ContextThreadPoolExecutor(max_workers=2) as executor: # Submit all search tasks future_explicit = executor.submit( self.vector_db.search, query_embedding, "explicit_preference", top_k * 2, info diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index 2c423e6b6..41011df14 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -9,10 +9,10 @@ import requests from memos.log import get_logger +from memos.utils import timed from .base import BaseReranker from .concat import concat_original_source -from memos.utils import timed logger = get_logger(__name__) @@ -119,7 +119,7 @@ def __init__( self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys) self._warned_missing_keys: set[str] = set() - @timed + @timed(log=True, log_prefix="RerankerAPI") def rerank( self, query: str, diff --git a/src/memos/reranker/http_bge_strategy.py b/src/memos/reranker/http_bge_strategy.py index 8cbf633a6..b0567698c 100644 --- a/src/memos/reranker/http_bge_strategy.py +++ b/src/memos/reranker/http_bge_strategy.py @@ -10,6 +10,7 @@ from memos.log import get_logger from memos.reranker.strategies import RerankerStrategyFactory +from memos.utils import timed from .base import BaseReranker @@ -119,6 +120,7 @@ def __init__( self._warned_missing_keys: set[str] = set() self.reranker_strategy = RerankerStrategyFactory.from_config(reranker_strategy) + @timed(log=True, log_prefix="RerankerStrategy") def rerank( self, query: str, diff --git a/src/memos/utils.py b/src/memos/utils.py index 6a1d42558..08934ed34 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -6,14 +6,24 @@ logger = get_logger(__name__) -def timed(func): - """Decorator to measure and log time of retrieval steps.""" +def timed(func=None, *, log=False, log_prefix=""): + """Decorator to measure and optionally log time of retrieval steps. - def wrapper(*args, **kwargs): - start = time.perf_counter() - result = func(*args, **kwargs) - elapsed = time.perf_counter() - start - logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s") - return result + Can be used as @timed or @timed(log=True) + """ - return wrapper + def decorator(fn): + def wrapper(*args, **kwargs): + start = time.perf_counter() + result = fn(*args, **kwargs) + elapsed = time.perf_counter() - start + if log: + logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed:.2f} seconds") + return result + + return wrapper + + # Handle both @timed and @timed(log=True) cases + if func is None: + return decorator + return decorator(func)