Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 53 additions & 79 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,14 +432,13 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
def remove_oldest_memory(
self, memory_type: str, keep_latest: int, user_name: str | None = None
) -> None:
"""
Remove all WorkingMemory nodes except the latest `keep_latest` entries.

Args:
memory_type (str): Memory type (e.g., 'WorkingMemory', 'LongTermMemory').
keep_latest (int): Number of latest WorkingMemory entries to keep.
user_name (str, optional): User name for filtering in non-multi-db mode
"""
start_time = time.perf_counter()
logger.info(
"remove_oldest_memory by memory_type:%s,keep_latest: %s,user_name:%s",
memory_type,
keep_latest,
user_name,
)
user_name = user_name if user_name else self._get_config_value("user_name")

# Use actual OFFSET logic, consistent with nebular.py
Expand All @@ -456,6 +455,9 @@ def remove_oldest_memory(
self.format_param_value(user_name),
keep_latest,
]
logger.info(
f"remove_oldest_memory by select_query:{select_query},select_params:{select_params}"
)
try:
with self._get_connection() as conn, conn.cursor() as cursor:
# Execute query to get IDs to delete
Expand All @@ -482,6 +484,8 @@ def remove_oldest_memory(
f"keeping {keep_latest} latest for user {user_name}, "
f"removed ids: {ids_to_delete}"
)
elapsed = (time.perf_counter() - start_time) * 1000.0
logger.info("remove_oldest_memory internal took %.1f ms", elapsed)
except Exception as e:
logger.error(f"[remove_oldest_memory] Failed: {e}", exc_info=True)
raise
Expand Down Expand Up @@ -1840,9 +1844,8 @@ def search_by_embedding(
**kwargs,
) -> list[dict]:
logger.info(
"search_by_embedding user_name:%s,filter: %s, knowledgebase_ids: %s,scope:%s,status:%s,search_filter:%s,filter:%s,knowledgebase_ids:%s,return_fields:%s",
"search_by_embedding by user_name:%s,knowledgebase_ids: %s,scope:%s,status:%s,search_filter:%s,filter:%s,knowledgebase_ids:%s,return_fields:%s",
user_name,
filter,
knowledgebase_ids,
scope,
status,
Expand Down Expand Up @@ -1895,20 +1898,21 @@ def search_by_embedding(
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

query = f"""
set hnsw.ef_search = 100;set hnsw.iterative_scan = relaxed_order;
WITH t AS (
SELECT id,
properties,
timeline,
ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id,
(1 - (embedding <=> %s::vector(1024))) AS scope
(embedding <=> %s::vector(1024)) AS scope_distance
FROM "{self.db_name}_graph"."Memory"
{where_clause}
ORDER BY scope DESC
ORDER BY scope_distance ASC
LIMIT {top_k}
)
SELECT *
SELECT *,(1 - scope_distance) AS scope
FROM t
WHERE scope > 0.1;
WHERE scope_distance < 0.9;
"""
vector_str = convert_to_vector(vector)
query = query.replace("%s::vector(1024)", f"'{vector_str}'::vector(1024)")
Expand Down Expand Up @@ -1953,7 +1957,7 @@ def search_by_embedding(
output.append(item)
elapsed_time = (time.perf_counter() - start_time) * 1000.0
logger.info(
"search_by_embedding query embedding completed time took %.1f ms", elapsed_time
"search_by_embedding query by embedding completed time took %.1f ms", elapsed_time
)
return output[:top_k]

Expand All @@ -1966,57 +1970,34 @@ def get_by_metadata(
knowledgebase_ids: list | None = None,
user_name_flag: bool = True,
) -> list[str]:
"""
Retrieve node IDs that match given metadata filters.
Supports exact match.

Args:
filters: List of filter dicts like:
[
{"field": "key", "op": "in", "value": ["A", "B"]},
{"field": "confidence", "op": ">=", "value": 80},
{"field": "tags", "op": "contains", "value": "AI"},
...
]
user_name (str, optional): User name for filtering in non-multi-db mode

Returns:
list[str]: Node IDs whose metadata match the filter conditions. (AND logic).
"""
start_time = time.perf_counter()
logger.info(
f" get_by_metadata user_name:{user_name},filter: {filter}, knowledgebase_ids: {knowledgebase_ids},filters:{filters}"
)

user_name = user_name if user_name else self._get_config_value("user_name")

# Build WHERE conditions for cypher query
where_conditions = []

for f in filters:
field = f["field"]
op = f.get("op", "=")
value = f["value"]

# Format value
if isinstance(value, str):
# Escape single quotes using backslash when inside $$ dollar-quoted strings
# In $$ delimiters, Cypher string literals can use \' to escape single quotes
escaped_str = value.replace("'", "\\'")
escaped_value = f"'{escaped_str}'"
elif isinstance(value, list):
# Handle list values - use double quotes for Cypher arrays
list_items = []
for v in value:
if isinstance(v, str):
# Escape double quotes in string values for Cypher
escaped_str = v.replace('"', '\\"')
list_items.append(f'"{escaped_str}"')
else:
list_items.append(str(v))
escaped_value = f"[{', '.join(list_items)}]"
else:
escaped_value = f"'{value}'" if isinstance(value, str) else str(value)
# Build WHERE conditions
if op == "=":
where_conditions.append(f"n.{field} = {escaped_value}")
elif op == "in":
Expand Down Expand Up @@ -2045,22 +2026,19 @@ def get_by_metadata(
knowledgebase_ids=knowledgebase_ids,
default_user_name=self._get_config_value("user_name"),
)
logger.info(f"[get_by_metadata] user_name_conditions: {user_name_conditions}")
logger.info(f"get_by_metadata user_name_conditions: {user_name_conditions}")

# Add user_name WHERE clause
if user_name_conditions:
if len(user_name_conditions) == 1:
where_conditions.append(user_name_conditions[0])
else:
where_conditions.append(f"({' OR '.join(user_name_conditions)})")

# Build filter conditions using common method
filter_where_clause = self._build_filter_conditions_cypher(filter)
logger.info(f"[get_by_metadata] filter_where_clause: {filter_where_clause}")
logger.info(f"get_by_metadata filter_where_clause: {filter_where_clause}")

where_str = " AND ".join(where_conditions) + filter_where_clause

# Use cypher query
cypher_query = f"""
SELECT * FROM cypher('{self.db_name}_graph', $$
MATCH (n:Memory)
Expand All @@ -2070,15 +2048,16 @@ def get_by_metadata(
"""

ids = []
logger.info(f"[get_by_metadata] cypher_query: {cypher_query}")
logger.info(f"get_by_metadata cypher_query: {cypher_query}")
try:
with self._get_connection() as conn, conn.cursor() as cursor:
cursor.execute(cypher_query)
results = cursor.fetchall()
ids = [str(item[0]).strip('"') for item in results]
except Exception as e:
logger.warning(f"Failed to get metadata: {e}, query is {cypher_query}")

elapsed = (time.perf_counter() - start_time) * 1000.0
logger.info("get_by_metadata internal took %.1f ms", elapsed)
return ids

@timed
Expand Down Expand Up @@ -2165,25 +2144,19 @@ def get_grouped_counts(
params: dict[str, Any] | None = None,
user_name: str | None = None,
) -> list[dict[str, Any]]:
"""
Count nodes grouped by any fields.

Args:
group_fields (list[str]): Fields to group by, e.g., ["memory_type", "status"]
where_clause (str, optional): Extra WHERE condition. E.g.,
"WHERE n.status = 'activated'"
params (dict, optional): Parameters for WHERE clause.
user_name (str, optional): User name for filtering in non-multi-db mode

Returns:
list[dict]: e.g., [{ 'memory_type': 'WorkingMemory', 'status': 'active', 'count': 10 }, ...]
"""
start_time = time.perf_counter()
logger.info(
"get_grouped_counts by group_fields:%s,where_clause: %s,params:%s,user_name:%s",
group_fields,
where_clause,
params,
user_name,
)
if not group_fields:
raise ValueError("group_fields cannot be empty")

user_name = user_name if user_name else self._get_config_value("user_name")

# Build user clause
user_clause = f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype"
if where_clause:
where_clause = where_clause.strip()
Expand All @@ -2194,44 +2167,43 @@ def get_grouped_counts(
else:
where_clause = f"WHERE {user_clause}"

# Inline parameters if provided
if params and isinstance(params, dict):
for key, value in params.items():
# Handle different value types appropriately
if isinstance(value, str):
value = f"'{value}'"
where_clause = where_clause.replace(f"${key}", str(value))

# Handle user_name parameter in where_clause
if "user_name = %s" in where_clause:
where_clause = where_clause.replace(
"user_name = %s",
f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype",
)

# Build return fields and group by fields
return_fields = []
group_by_fields = []

cte_select_list = []
aliases = []
for field in group_fields:
alias = field.replace(".", "_")
return_fields.append(
f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype)::text AS {alias}"
)
group_by_fields.append(
f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype)::text"
aliases.append(alias)
cte_select_list.append(
f"ag_catalog.agtype_access_operator(properties, '\"{field}\"'::agtype) AS {alias}"
)

# Full SQL query construction
outer_select = ", ".join(f"{a}::text" for a in aliases)
outer_group_by = ", ".join(aliases)
query = f"""
SELECT {", ".join(return_fields)}, COUNT(*) AS count
FROM "{self.db_name}_graph"."Memory"
{where_clause}
GROUP BY {", ".join(group_by_fields)}
WITH t AS (
SELECT {", ".join(cte_select_list)}
FROM "{self.db_name}_graph"."Memory"
{where_clause}
LIMIT 1000
)
SELECT {outer_select}, count(*) AS count
FROM t
GROUP BY {outer_group_by}
"""
logger.info(f"get_grouped_counts query:{query},params:{params}")

try:
with self._get_connection() as conn, conn.cursor() as cursor:
# Handle parameterized query
if params and isinstance(params, list):
cursor.execute(query, params)
else:
Expand All @@ -2250,6 +2222,8 @@ def get_grouped_counts(
count_value = row[-1] # Last column is count
output.append({**group_values, "count": int(count_value)})

elapsed = (time.perf_counter() - start_time) * 1000.0
logger.info("get_grouped_counts internal took %.1f ms", elapsed)
return output

except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion src/memos/mem_scheduler/schemas/general_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300
DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2
DEFAULT_STUCK_THREAD_TOLERANCE = 10
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 200
DEFAULT_TOP_K = 5
DEFAULT_CONTEXT_WINDOW_SIZE = 5
DEFAULT_USE_REDIS_QUEUE = os.getenv("MEMSCHEDULER_USE_REDIS_QUEUE", "False").lower() == "true"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def add(

if mode == "sync":
self._cleanup_working_memory(user_name)
self._refresh_memory_size(user_name=user_name)

return added_ids

Expand Down
Loading