diff --git a/src/memos/api/middleware/request_context.py b/src/memos/api/middleware/request_context.py index cb41428d4..a7d0dc967 100644 --- a/src/memos/api/middleware/request_context.py +++ b/src/memos/api/middleware/request_context.py @@ -2,8 +2,11 @@ Request context middleware for automatic trace_id injection. """ +import time + from collections.abc import Callable +from fastapi.responses import StreamingResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import Response @@ -13,6 +16,20 @@ from memos.context.context import RequestContext, generate_trace_id, set_request_context +async def _tee_stream( + original: StreamingResponse, +) -> StreamingResponse: + chunks = [] + + async for chunk in original.body_iterator: + chunks.append(chunk) + yield chunk + + body_str = "".join(chunks).decode("utf-8", errors="replace") + + logger.info(f"Response content: {body_str}") + + logger = memos.log.get_logger(__name__) @@ -38,8 +55,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: # Extract or generate trace_id trace_id = extract_trace_id_from_headers(request) or generate_trace_id() + env = request.headers.get("x-env") + user_type = request.headers.get("x-user-type") + user_name = request.headers.get("x-user-name") + start_time = time.time() + # Create and set request context - context = RequestContext(trace_id=trace_id, api_path=request.url.path) + context = RequestContext( + trace_id=trace_id, + api_path=request.url.path, + env=env, + user_type=user_type, + user_name=user_name, + ) set_request_context(context) # Log request start with parameters @@ -49,15 +77,26 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response: if request.query_params: params_log["query_params"] = dict(request.query_params) - logger.info(f"Request started: {request.method} {request.url.path}, {params_log}") + logger.info(f"Request started, params: {params_log}, headers: {request.headers}") # Process the request - response = await call_next(request) - - # Log request completion with output - logger.info(f"Request completed: {request.url.path}, status: {response.status_code}") - - # Add trace_id to response headers for debugging - response.headers["x-trace-id"] = trace_id + try: + response = await call_next(request) + end_time = time.time() + + if response.status_code == 200: + logger.info( + f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + else: + logger.error( + f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + except Exception as e: + end_time = time.time() + logger.error( + f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms" + ) + raise e return response diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index a332de583..8a2f0a968 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -237,6 +237,7 @@ def search_memories(search_req: APISearchRequest): } ) + logger.info(f"search_memories response data: {memories_result}") return SearchResponse( message="Search completed successfully", data=memories_result, @@ -285,6 +286,7 @@ def add_memories(add_req: APIADDRequest): } for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False) ] + logger.info(f"add_memories response data: {response_data}") return MemoryResponse( message="Memory added successfully", data=response_data, diff --git a/src/memos/context/context.py b/src/memos/context/context.py index 4f54348fb..d6a0f3bf1 100644 --- a/src/memos/context/context.py +++ b/src/memos/context/context.py @@ -29,9 +29,19 @@ class RequestContext: This provides a Flask g-like object for FastAPI applications. """ - def __init__(self, trace_id: str | None = None, api_path: str | None = None): + def __init__( + self, + trace_id: str | None = None, + api_path: str | None = None, + env: str | None = None, + user_type: str | None = None, + user_name: str | None = None, + ): self.trace_id = trace_id or "trace-id" self.api_path = api_path + self.env = env + self.user_type = user_type + self.user_name = user_name self._data: dict[str, Any] = {} def set(self, key: str, value: Any) -> None: @@ -43,7 +53,13 @@ def get(self, key: str, default: Any | None = None) -> Any: return self._data.get(key, default) def __setattr__(self, name: str, value: Any) -> None: - if name.startswith("_") or name in ("trace_id", "api_path"): + if name.startswith("_") or name in ( + "trace_id", + "api_path", + "env", + "user_type", + "user_name", + ): super().__setattr__(name, value) else: if not hasattr(self, "_data"): @@ -58,7 +74,14 @@ def __getattr__(self, name: str) -> Any: def to_dict(self) -> dict[str, Any]: """Convert context to dictionary.""" - return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()} + return { + "trace_id": self.trace_id, + "api_path": self.api_path, + "env": self.env, + "user_type": self.user_type, + "user_name": self.user_name, + "data": self._data.copy(), + } def set_request_context(context: RequestContext) -> None: @@ -93,6 +116,36 @@ def get_current_api_path() -> str | None: return None +def get_current_env() -> str | None: + """ + Get the current request's env. + """ + context = _request_context.get() + if context: + return context.get("env") + return "prod" + + +def get_current_user_type() -> str | None: + """ + Get the current request's user type. + """ + context = _request_context.get() + if context: + return context.get("user_type") + return "opensource" + + +def get_current_user_name() -> str | None: + """ + Get the current request's user name. + """ + context = _request_context.get() + if context: + return context.get("user_name") + return "memos" + + def get_current_context() -> RequestContext | None: """ Get the current request context. @@ -103,7 +156,11 @@ def get_current_context() -> RequestContext | None: context_dict = _request_context.get() if context_dict: ctx = RequestContext( - trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path") + trace_id=context_dict.get("trace_id"), + api_path=context_dict.get("api_path"), + env=context_dict.get("env"), + user_type=context_dict.get("user_type"), + user_name=context_dict.get("user_name"), ) ctx._data = context_dict.get("data", {}).copy() return ctx @@ -141,6 +198,9 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs): self.main_trace_id = get_current_trace_id() self.main_api_path = get_current_api_path() + self.main_env = get_current_env() + self.main_user_type = get_current_user_type() + self.main_user_name = get_current_user_name() self.main_context = get_current_context() def run(self): @@ -148,7 +208,11 @@ def run(self): if self.main_context: # Copy the context data child_context = RequestContext( - trace_id=self.main_trace_id, api_path=self.main_context.api_path + trace_id=self.main_trace_id, + api_path=self.main_api_path, + env=self.main_env, + user_type=self.main_user_type, + user_name=self.main_user_name, ) child_context._data = self.main_context._data.copy() @@ -171,13 +235,22 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any: """ main_trace_id = get_current_trace_id() main_api_path = get_current_api_path() + main_env = get_current_env() + main_user_type = get_current_user_type() + main_user_name = get_current_user_name() main_context = get_current_context() @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: if main_context: # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path) + child_context = RequestContext( + trace_id=main_trace_id, + api_path=main_api_path, + env=main_env, + user_type=main_user_type, + user_name=main_user_name, + ) child_context._data = main_context._data.copy() set_request_context(child_context) @@ -198,13 +271,22 @@ def map( """ main_trace_id = get_current_trace_id() main_api_path = get_current_api_path() + main_env = get_current_env() + main_user_type = get_current_user_type() + main_user_name = get_current_user_name() main_context = get_current_context() @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: if main_context: # Create and set new context in worker thread - child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path) + child_context = RequestContext( + trace_id=main_trace_id, + api_path=main_api_path, + env=main_env, + user_type=main_user_type, + user_name=main_user_name, + ) child_context._data = main_context._data.copy() set_request_context(child_context) diff --git a/src/memos/log.py b/src/memos/log.py index 339d13f26..d46bfa7f5 100644 --- a/src/memos/log.py +++ b/src/memos/log.py @@ -14,7 +14,13 @@ from dotenv import load_dotenv from memos import settings -from memos.context.context import get_current_api_path, get_current_trace_id +from memos.context.context import ( + get_current_api_path, + get_current_env, + get_current_trace_id, + get_current_user_name, + get_current_user_type, +) # Load environment variables @@ -34,15 +40,22 @@ def _setup_logfile() -> Path: return logfile -class TraceIDFilter(logging.Filter): - """add trace_id to the log record""" +class ContextFilter(logging.Filter): + """add context to the log record""" def filter(self, record): try: trace_id = get_current_trace_id() record.trace_id = trace_id if trace_id else "trace-id" + record.env = get_current_env() + record.user_type = get_current_user_type() + record.user_name = get_current_user_name() + record.api_path = get_current_api_path() except Exception: record.trace_id = "trace-id" + record.env = "prod" + record.user_type = "normal" + record.user_name = "unknown" return True @@ -86,13 +99,24 @@ def emit(self, record): try: trace_id = get_current_trace_id() or "trace-id" api_path = get_current_api_path() + env = get_current_env() + user_type = get_current_user_type() + user_name = get_current_user_name() if api_path is not None: - self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path) + self._executor.submit( + self._send_log_sync, + record.getMessage(), + trace_id, + api_path, + env, + user_type, + user_name, + ) except Exception as e: if not self._is_shutting_down.is_set(): print(f"Error sending log: {e}") - def _send_log_sync(self, message, trace_id, api_path): + def _send_log_sync(self, message, trace_id, api_path, env, user_type, user_name): """Send log message synchronously in a separate thread""" try: logger_url = os.getenv("CUSTOM_LOGGER_URL") @@ -104,6 +128,9 @@ def _send_log_sync(self, message, trace_id, api_path): "trace_id": trace_id, "action": api_path, "current_time": round(time.time(), 3), + "env": env, + "user_type": user_type, + "user_name": user_name, } # Add auth token if exists @@ -145,18 +172,18 @@ def close(self): "disable_existing_loggers": False, "formatters": { "standard": { - "format": "%(asctime)s [%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" + "format": "%(asctime)s | %(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" }, "no_datetime": { - "format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" + "format": "%(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s" }, "simplified": { - "format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s" + "format": "%(asctime)s | %(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | % %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s" }, }, "filters": { "package_tree_filter": {"()": "logging.Filter", "name": settings.LOG_FILTER_TREE_PREFIX}, - "trace_id_filter": {"()": "memos.log.TraceIDFilter"}, + "context_filter": {"()": "memos.log.ContextFilter"}, }, "handlers": { "console": { @@ -164,7 +191,7 @@ def close(self): "class": "logging.StreamHandler", "stream": stdout, "formatter": "no_datetime", - "filters": ["package_tree_filter", "trace_id_filter"], + "filters": ["package_tree_filter", "context_filter"], }, "file": { "level": "DEBUG", @@ -173,7 +200,7 @@ def close(self): "maxBytes": 1024**2 * 10, "backupCount": 10, "formatter": "standard", - "filters": ["trace_id_filter"], + "filters": ["context_filter"], }, "custom_logger": { "level": "INFO", diff --git a/src/memos/utils.py b/src/memos/utils.py index 6a1d42558..5801bc2d2 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -1,5 +1,6 @@ import time +from memos import settings from memos.log import get_logger @@ -13,7 +14,8 @@ def wrapper(*args, **kwargs): start = time.perf_counter() result = func(*args, **kwargs) elapsed = time.perf_counter() - start - logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s") + if settings.DEBUG: + logger.info(f"[TIMER] {func.__name__} took {elapsed:.2f} s") return result return wrapper