Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions src/memos/api/middleware/request_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
Request context middleware for automatic trace_id injection.
"""

import time

from collections.abc import Callable

from fastapi.responses import StreamingResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response
Expand All @@ -13,6 +16,20 @@
from memos.context.context import RequestContext, generate_trace_id, set_request_context


async def _tee_stream(
original: StreamingResponse,
) -> StreamingResponse:
chunks = []

async for chunk in original.body_iterator:
chunks.append(chunk)
yield chunk

body_str = "".join(chunks).decode("utf-8", errors="replace")

logger.info(f"Response content: {body_str}")


logger = memos.log.get_logger(__name__)


Expand All @@ -38,8 +55,19 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
# Extract or generate trace_id
trace_id = extract_trace_id_from_headers(request) or generate_trace_id()

env = request.headers.get("x-env")
user_type = request.headers.get("x-user-type")
user_name = request.headers.get("x-user-name")
start_time = time.time()

# Create and set request context
context = RequestContext(trace_id=trace_id, api_path=request.url.path)
context = RequestContext(
trace_id=trace_id,
api_path=request.url.path,
env=env,
user_type=user_type,
user_name=user_name,
)
set_request_context(context)

# Log request start with parameters
Expand All @@ -49,15 +77,26 @@ async def dispatch(self, request: Request, call_next: Callable) -> Response:
if request.query_params:
params_log["query_params"] = dict(request.query_params)

logger.info(f"Request started: {request.method} {request.url.path}, {params_log}")
logger.info(f"Request started, params: {params_log}, headers: {request.headers}")

# Process the request
response = await call_next(request)

# Log request completion with output
logger.info(f"Request completed: {request.url.path}, status: {response.status_code}")

# Add trace_id to response headers for debugging
response.headers["x-trace-id"] = trace_id
try:
response = await call_next(request)
end_time = time.time()

if response.status_code == 200:
logger.info(
f"Request completed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
else:
logger.error(
f"Request Failed: {request.url.path}, status: {response.status_code}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
except Exception as e:
end_time = time.time()
logger.error(
f"Request Exception Error: {e}, cost: {(end_time - start_time) * 1000:.2f}ms"
)
raise e

return response
2 changes: 2 additions & 0 deletions src/memos/api/routers/server_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ def search_memories(search_req: APISearchRequest):
}
)

logger.info(f"search_memories response data: {memories_result}")
return SearchResponse(
message="Search completed successfully",
data=memories_result,
Expand Down Expand Up @@ -285,6 +286,7 @@ def add_memories(add_req: APIADDRequest):
}
for memory_id, memory in zip(mem_id_list, flattened_memories, strict=False)
]
logger.info(f"add_memories response data: {response_data}")
return MemoryResponse(
message="Memory added successfully",
data=response_data,
Expand Down
96 changes: 89 additions & 7 deletions src/memos/context/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,19 @@ class RequestContext:
This provides a Flask g-like object for FastAPI applications.
"""

def __init__(self, trace_id: str | None = None, api_path: str | None = None):
def __init__(
self,
trace_id: str | None = None,
api_path: str | None = None,
env: str | None = None,
user_type: str | None = None,
user_name: str | None = None,
):
self.trace_id = trace_id or "trace-id"
self.api_path = api_path
self.env = env
self.user_type = user_type
self.user_name = user_name
self._data: dict[str, Any] = {}

def set(self, key: str, value: Any) -> None:
Expand All @@ -43,7 +53,13 @@ def get(self, key: str, default: Any | None = None) -> Any:
return self._data.get(key, default)

def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_") or name in ("trace_id", "api_path"):
if name.startswith("_") or name in (
"trace_id",
"api_path",
"env",
"user_type",
"user_name",
):
super().__setattr__(name, value)
else:
if not hasattr(self, "_data"):
Expand All @@ -58,7 +74,14 @@ def __getattr__(self, name: str) -> Any:

def to_dict(self) -> dict[str, Any]:
"""Convert context to dictionary."""
return {"trace_id": self.trace_id, "api_path": self.api_path, "data": self._data.copy()}
return {
"trace_id": self.trace_id,
"api_path": self.api_path,
"env": self.env,
"user_type": self.user_type,
"user_name": self.user_name,
"data": self._data.copy(),
}


def set_request_context(context: RequestContext) -> None:
Expand Down Expand Up @@ -93,6 +116,36 @@ def get_current_api_path() -> str | None:
return None


def get_current_env() -> str | None:
"""
Get the current request's env.
"""
context = _request_context.get()
if context:
return context.get("env")
return "prod"


def get_current_user_type() -> str | None:
"""
Get the current request's user type.
"""
context = _request_context.get()
if context:
return context.get("user_type")
return "opensource"


def get_current_user_name() -> str | None:
"""
Get the current request's user name.
"""
context = _request_context.get()
if context:
return context.get("user_name")
return "memos"


def get_current_context() -> RequestContext | None:
"""
Get the current request context.
Expand All @@ -103,7 +156,11 @@ def get_current_context() -> RequestContext | None:
context_dict = _request_context.get()
if context_dict:
ctx = RequestContext(
trace_id=context_dict.get("trace_id"), api_path=context_dict.get("api_path")
trace_id=context_dict.get("trace_id"),
api_path=context_dict.get("api_path"),
env=context_dict.get("env"),
user_type=context_dict.get("user_type"),
user_name=context_dict.get("user_name"),
)
ctx._data = context_dict.get("data", {}).copy()
return ctx
Expand Down Expand Up @@ -141,14 +198,21 @@ def __init__(self, target, args=(), kwargs=None, **thread_kwargs):

self.main_trace_id = get_current_trace_id()
self.main_api_path = get_current_api_path()
self.main_env = get_current_env()
self.main_user_type = get_current_user_type()
self.main_user_name = get_current_user_name()
self.main_context = get_current_context()

def run(self):
# Create a new RequestContext with the main thread's trace_id
if self.main_context:
# Copy the context data
child_context = RequestContext(
trace_id=self.main_trace_id, api_path=self.main_context.api_path
trace_id=self.main_trace_id,
api_path=self.main_api_path,
env=self.main_env,
user_type=self.main_user_type,
user_name=self.main_user_name,
)
child_context._data = self.main_context._data.copy()

Expand All @@ -171,13 +235,22 @@ def submit(self, fn: Callable[..., T], *args: Any, **kwargs: Any) -> Any:
"""
main_trace_id = get_current_trace_id()
main_api_path = get_current_api_path()
main_env = get_current_env()
main_user_type = get_current_user_type()
main_user_name = get_current_user_name()
main_context = get_current_context()

@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if main_context:
# Create and set new context in worker thread
child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
child_context = RequestContext(
trace_id=main_trace_id,
api_path=main_api_path,
env=main_env,
user_type=main_user_type,
user_name=main_user_name,
)
child_context._data = main_context._data.copy()
set_request_context(child_context)

Expand All @@ -198,13 +271,22 @@ def map(
"""
main_trace_id = get_current_trace_id()
main_api_path = get_current_api_path()
main_env = get_current_env()
main_user_type = get_current_user_type()
main_user_name = get_current_user_name()
main_context = get_current_context()

@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if main_context:
# Create and set new context in worker thread
child_context = RequestContext(trace_id=main_trace_id, api_path=main_api_path)
child_context = RequestContext(
trace_id=main_trace_id,
api_path=main_api_path,
env=main_env,
user_type=main_user_type,
user_name=main_user_name,
)
child_context._data = main_context._data.copy()
set_request_context(child_context)

Expand Down
49 changes: 38 additions & 11 deletions src/memos/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
from dotenv import load_dotenv

from memos import settings
from memos.context.context import get_current_api_path, get_current_trace_id
from memos.context.context import (
get_current_api_path,
get_current_env,
get_current_trace_id,
get_current_user_name,
get_current_user_type,
)


# Load environment variables
Expand All @@ -34,15 +40,22 @@ def _setup_logfile() -> Path:
return logfile


class TraceIDFilter(logging.Filter):
"""add trace_id to the log record"""
class ContextFilter(logging.Filter):
"""add context to the log record"""

def filter(self, record):
try:
trace_id = get_current_trace_id()
record.trace_id = trace_id if trace_id else "trace-id"
record.env = get_current_env()
record.user_type = get_current_user_type()
record.user_name = get_current_user_name()
record.api_path = get_current_api_path()
except Exception:
record.trace_id = "trace-id"
record.env = "prod"
record.user_type = "normal"
record.user_name = "unknown"
return True


Expand Down Expand Up @@ -86,13 +99,24 @@ def emit(self, record):
try:
trace_id = get_current_trace_id() or "trace-id"
api_path = get_current_api_path()
env = get_current_env()
user_type = get_current_user_type()
user_name = get_current_user_name()
if api_path is not None:
self._executor.submit(self._send_log_sync, record.getMessage(), trace_id, api_path)
self._executor.submit(
self._send_log_sync,
record.getMessage(),
trace_id,
api_path,
env,
user_type,
user_name,
)
except Exception as e:
if not self._is_shutting_down.is_set():
print(f"Error sending log: {e}")

def _send_log_sync(self, message, trace_id, api_path):
def _send_log_sync(self, message, trace_id, api_path, env, user_type, user_name):
"""Send log message synchronously in a separate thread"""
try:
logger_url = os.getenv("CUSTOM_LOGGER_URL")
Expand All @@ -104,6 +128,9 @@ def _send_log_sync(self, message, trace_id, api_path):
"trace_id": trace_id,
"action": api_path,
"current_time": round(time.time(), 3),
"env": env,
"user_type": user_type,
"user_name": user_name,
}

# Add auth token if exists
Expand Down Expand Up @@ -145,26 +172,26 @@ def close(self):
"disable_existing_loggers": False,
"formatters": {
"standard": {
"format": "%(asctime)s [%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
"format": "%(asctime)s | %(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
},
"no_datetime": {
"format": "[%(trace_id)s] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
"format": "%(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(funcName)s - %(message)s"
},
"simplified": {
"format": "%(asctime)s | %(trace_id)s | %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s"
"format": "%(asctime)s | %(trace_id)s | %(api_path)s | env=%(env)s | user_type=%(user_type)s | user_name=%(user_name)s | % %(levelname)s | %(filename)s:%(lineno)d: %(funcName)s | %(message)s"
},
},
"filters": {
"package_tree_filter": {"()": "logging.Filter", "name": settings.LOG_FILTER_TREE_PREFIX},
"trace_id_filter": {"()": "memos.log.TraceIDFilter"},
"context_filter": {"()": "memos.log.ContextFilter"},
},
"handlers": {
"console": {
"level": selected_log_level,
"class": "logging.StreamHandler",
"stream": stdout,
"formatter": "no_datetime",
"filters": ["package_tree_filter", "trace_id_filter"],
"filters": ["package_tree_filter", "context_filter"],
},
"file": {
"level": "DEBUG",
Expand All @@ -173,7 +200,7 @@ def close(self):
"maxBytes": 1024**2 * 10,
"backupCount": 10,
"formatter": "standard",
"filters": ["trace_id_filter"],
"filters": ["context_filter"],
},
"custom_logger": {
"level": "INFO",
Expand Down
Loading