diff --git a/dynamiq/connections/connections.py b/dynamiq/connections/connections.py index 746aa0df6..83430b975 100644 --- a/dynamiq/connections/connections.py +++ b/dynamiq/connections/connections.py @@ -69,6 +69,9 @@ def to_dict(self, for_tracing: bool = False, **kwargs) -> dict: Returns: dict: A dictionary representation of the connection instance. """ + # Drop forwarded `include_secure_params` — has no effect here and + # `model_dump` would reject it as an unknown kwarg. + kwargs.pop("include_secure_params", None) if for_tracing: return {"id": self.id, "type": self.type} return self.model_dump(**kwargs) diff --git a/dynamiq/memory/long_term/__init__.py b/dynamiq/memory/long_term/__init__.py new file mode 100644 index 000000000..fd3be3e01 --- /dev/null +++ b/dynamiq/memory/long_term/__init__.py @@ -0,0 +1,14 @@ +from dynamiq.memory.long_term.base import LongTermMemoryBackend, LongTermMemoryError +from dynamiq.memory.long_term.long_term_memory import LongTermMemoryConfig +from dynamiq.memory.long_term.schemas import Fact +from dynamiq.memory.long_term.types import ForgetStatus, MemoryToolKind, RememberOutcome + +__all__ = [ + "Fact", + "ForgetStatus", + "LongTermMemoryBackend", + "LongTermMemoryConfig", + "LongTermMemoryError", + "MemoryToolKind", + "RememberOutcome", +] diff --git a/dynamiq/memory/long_term/backends/__init__.py b/dynamiq/memory/long_term/backends/__init__.py new file mode 100644 index 000000000..cdde9417e --- /dev/null +++ b/dynamiq/memory/long_term/backends/__init__.py @@ -0,0 +1,13 @@ +from dynamiq.memory.long_term.backends.in_memory import InMemoryLongTermMemoryBackend +from dynamiq.memory.long_term.backends.pgvector import PostgresLongTermMemoryBackend +from dynamiq.memory.long_term.backends.pinecone import PineconeLongTermMemoryBackend +from dynamiq.memory.long_term.backends.qdrant import QdrantLongTermMemoryBackend +from dynamiq.memory.long_term.backends.weaviate import WeaviateLongTermMemoryBackend + +__all__ = [ + 
"InMemoryLongTermMemoryBackend", + "PineconeLongTermMemoryBackend", + "PostgresLongTermMemoryBackend", + "QdrantLongTermMemoryBackend", + "WeaviateLongTermMemoryBackend", +] diff --git a/dynamiq/memory/long_term/backends/in_memory.py b/dynamiq/memory/long_term/backends/in_memory.py new file mode 100644 index 000000000..7fbd38532 --- /dev/null +++ b/dynamiq/memory/long_term/backends/in_memory.py @@ -0,0 +1,97 @@ +from datetime import datetime + +import numpy as np +from pydantic import PrivateAttr + +from dynamiq.memory.long_term.base import LongTermMemoryBackend +from dynamiq.memory.long_term.schemas import Fact + + +class InMemoryLongTermMemoryBackend(LongTermMemoryBackend): + """Dict + numpy-cosine backend. Loses data on restart.""" + + name: str = "in-memory-long-term-memory-backend" + + _facts: dict[str, Fact] = PrivateAttr(default_factory=dict) + _vectors: dict[str, list[float]] = PrivateAttr(default_factory=dict) + + def insert(self, fact: Fact, embedding: list[float]) -> None: + self._facts[fact.id] = fact + self._vectors[fact.id] = list(embedding) + + def get(self, fact_id: str) -> Fact | None: + return self._facts.get(fact_id) + + def get_by_hash(self, *, user_id: str, content_hash: str) -> Fact | None: + for fact in self._facts.values(): + if fact.user_id == user_id and fact.hash == content_hash: + return fact + return None + + def delete(self, fact_id: str) -> None: + self._facts.pop(fact_id, None) + self._vectors.pop(fact_id, None) + + def update( + self, + fact_id: str, + *, + content: str, + content_hash: str, + embedding: list[float], + metadata: dict, + updated_at: datetime, + ) -> None: + existing = self._facts.get(fact_id) + if existing is None: + return + self._facts[fact_id] = existing.model_copy( + update={"content": content, "hash": content_hash, "metadata": metadata, "updated_at": updated_at} + ) + self._vectors[fact_id] = list(embedding) + + def search( + self, *, query_embedding: list[float], + scope: dict[str, str], limit: int, + ) -> 
list[tuple[Fact, float]]: + if limit <= 0 or not self._facts: # non-positive limit or empty store: nothing to rank + return [] + + matched_facts = [f for f in self._facts.values() if _matches_scope(f, scope)] + if not matched_facts: + return [] + + matrix = np.asarray([self._vectors[f.id] for f in matched_facts], dtype=np.float64) + query = np.asarray(query_embedding, dtype=np.float64) + + # Cosine = (M @ q) / (||rows|| * ||q||); zero-norm rows fall back to 1 + # to avoid div-by-zero (the dot product is 0 anyway, so the score is 0). + row_norms = np.linalg.norm(matrix, axis=1) + row_norms[row_norms == 0] = 1.0 + query_norm = np.linalg.norm(query) or 1.0 + scores = (matrix @ query) / (row_norms * query_norm) + + k = min(limit, len(matched_facts)) + # argpartition gives the top-k unsorted; sort just that slice. + top_idx = np.argpartition(-scores, k - 1)[:k] + top_idx = top_idx[np.argsort(-scores[top_idx])] + return [(matched_facts[i], float(scores[i])) for i in top_idx] + + def list_by_scope( + self, scope: dict[str, str], limit: int = 100, + ) -> list[Fact]: + matched = [f for f in self._facts.values() if _matches_scope(f, scope)] + return matched[: max(limit, 0)] + + def delete_scope(self, scope: dict[str, str]) -> int: + to_delete = [fid for fid, f in self._facts.items() if _matches_scope(f, scope)] + for fid in to_delete: + self.delete(fid) + return len(to_delete) + + +def _matches_scope(fact: Fact, scope: dict[str, str]) -> bool: + for key, value in scope.items(): + if getattr(fact, key, None) != value: + return False + return True diff --git a/dynamiq/memory/long_term/backends/pgvector.py b/dynamiq/memory/long_term/backends/pgvector.py new file mode 100644 index 000000000..e8bbaac77 --- /dev/null +++ b/dynamiq/memory/long_term/backends/pgvector.py @@ -0,0 +1,238 @@ +from datetime import datetime +from typing import Any + +import psycopg +from pgvector.psycopg import register_vector +from psycopg.rows import dict_row +from psycopg.sql import SQL, Composed, Identifier +from psycopg.types.json import Jsonb +from pydantic import 
ConfigDict, Field, PrivateAttr + +from dynamiq.connections import PostgreSQL as PostgreSQLConnection +from dynamiq.memory.long_term.base import LongTermMemoryBackend +from dynamiq.memory.long_term.schemas import Fact + +_CREATE_EXTENSION_SQL = SQL("CREATE EXTENSION IF NOT EXISTS vector") + +_CREATE_TABLE_TEMPLATE = SQL( + """ + CREATE TABLE IF NOT EXISTS {table} ( + id TEXT PRIMARY KEY, + content TEXT NOT NULL, + hash TEXT NOT NULL, + user_id TEXT NOT NULL, + metadata JSONB NOT NULL DEFAULT '{{}}'::jsonb, + embedding vector({dim}) NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + updated_at TIMESTAMPTZ NOT NULL + ) + """ +) + +_CREATE_USER_ID_INDEX_TEMPLATE = SQL("CREATE INDEX IF NOT EXISTS {idx} ON {table} (user_id)") + +_CREATE_USER_HASH_INDEX_TEMPLATE = SQL("CREATE UNIQUE INDEX IF NOT EXISTS {idx} ON {table} (user_id, hash)") + + +def _scope_where_clause(scope: dict[str, str]) -> tuple[Composed, list]: + """Build a parameterised WHERE clause from a scope dict. + + Keys are interpolated as `Identifier` (safe); values stay as `%s` placeholders + for the driver — never an f-string substitution. 
+ """ + if not scope: + return SQL("TRUE"), [] + clauses = [SQL("{key} = %s").format(key=Identifier(key)) for key in scope.keys()] + return SQL(" AND ").join(clauses), list(scope.values()) + + +def _row_to_fact(row) -> Fact: + return Fact( + id=row["id"], + content=row["content"], + hash=row["hash"], + user_id=row["user_id"], + metadata=row["metadata"] or {}, + created_at=row["created_at"], + updated_at=row["updated_at"], + ) + + +_FACT_COLUMNS = SQL("id, content, hash, user_id, metadata, created_at, updated_at") + + +class PostgresLongTermMemoryBackend(LongTermMemoryBackend): + """Long-term memory backend backed by Postgres + pgvector.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str = "postgres-long-term-memory-backend" + connection: PostgreSQLConnection = Field(default_factory=PostgreSQLConnection) + table_name: str = "user_facts" + dimension: int = 1536 + + _conn: psycopg.Connection | None = PrivateAttr(default=None) + + @property + def to_dict_exclude_params(self) -> dict[str, bool]: + return super().to_dict_exclude_params | {"_conn": True, "connection": True} + + def to_dict(self, include_secure_params: bool = False, for_tracing: bool = False, **kwargs) -> dict[str, Any]: + # super() re-adds the embedder; we add the connection on top. 
+ data = super().to_dict(include_secure_params=include_secure_params, for_tracing=for_tracing, **kwargs) + data["connection"] = self.connection.to_dict( + for_tracing=for_tracing, include_secure_params=include_secure_params, **kwargs + ) + return data + + def model_post_init(self, __context) -> None: + self._conn = self.connection.connect() + self._conn.autocommit = True + with self._conn.cursor() as cur: + cur.execute(_CREATE_EXTENSION_SQL) + register_vector(self._conn) + + @property + def _table(self) -> Identifier: + """Return the table name wrapped as a safe SQL identifier.""" + return Identifier(self.table_name) + + def ensure_table(self) -> None: + """Create the facts table and indexes if absent. Safe to call repeatedly.""" + with self._conn.cursor() as cur: + cur.execute(_CREATE_EXTENSION_SQL) + cur.execute(_CREATE_TABLE_TEMPLATE.format(table=self._table, dim=SQL(str(self.dimension)))) + cur.execute( + _CREATE_USER_ID_INDEX_TEMPLATE.format( + idx=Identifier(f"{self.table_name}_user_id_idx"), + table=self._table, + ) + ) + cur.execute( + _CREATE_USER_HASH_INDEX_TEMPLATE.format( + idx=Identifier(f"{self.table_name}_user_hash_uidx"), + table=self._table, + ) + ) + + def recreate_table(self) -> None: + """Drop and re-create the facts table. Test-only helper.""" + with self._conn.cursor() as cur: + cur.execute(SQL("DROP TABLE IF EXISTS {table}").format(table=self._table)) + self.ensure_table() + + def drop_table(self) -> None: + """Drop the facts table if it exists. 
Test-only helper.""" + with self._conn.cursor() as cur: + cur.execute(SQL("DROP TABLE IF EXISTS {table}").format(table=self._table)) + + def insert(self, fact: Fact, embedding: list[float]) -> None: + with self._conn.cursor() as cur: + cur.execute( + SQL("INSERT INTO {table} ({cols}, embedding) " "VALUES (%s, %s, %s, %s, %s, %s, %s, %s)").format( + table=self._table, cols=_FACT_COLUMNS + ), + ( + fact.id, + fact.content, + fact.hash, + fact.user_id, + Jsonb(fact.metadata), + fact.created_at, + fact.updated_at, + embedding, + ), + ) + + def get(self, fact_id: str) -> Fact | None: + with self._conn.cursor(row_factory=dict_row) as cur: + cur.execute( + SQL("SELECT {cols} FROM {table} WHERE id = %s").format( + cols=_FACT_COLUMNS, + table=self._table, + ), + (fact_id,), + ) + row = cur.fetchone() + return _row_to_fact(row) if row else None + + def get_by_hash(self, *, user_id: str, content_hash: str) -> Fact | None: + with self._conn.cursor(row_factory=dict_row) as cur: + cur.execute( + SQL("SELECT {cols} FROM {table} WHERE user_id = %s AND hash = %s").format( + cols=_FACT_COLUMNS, table=self._table + ), + (user_id, content_hash), + ) + row = cur.fetchone() + return _row_to_fact(row) if row else None + + def delete(self, fact_id: str) -> None: + with self._conn.cursor() as cur: + cur.execute( + SQL("DELETE FROM {table} WHERE id = %s").format(table=self._table), + (fact_id,), + ) + + def update( + self, + fact_id: str, + *, + content: str, + content_hash: str, + embedding: list[float], + metadata: dict, + updated_at: datetime, + ) -> None: + with self._conn.cursor() as cur: + cur.execute( + SQL( + "UPDATE {table} SET content = %s, hash = %s, embedding = %s, " + "metadata = %s, updated_at = %s WHERE id = %s" + ).format(table=self._table), + (content, content_hash, embedding, Jsonb(metadata), updated_at, fact_id), + ) + + def search( + self, + *, + query_embedding: list[float], + scope: dict[str, str], + limit: int, + ) -> list[tuple[Fact, float]]: + where, params = 
_scope_where_clause(scope) + with self._conn.cursor(row_factory=dict_row) as cur: + cur.execute( + SQL( + "SELECT {cols}, 1 - (embedding <=> %s::vector) AS score " + "FROM {table} WHERE {where} " + "ORDER BY embedding <=> %s::vector LIMIT %s" + ).format(cols=_FACT_COLUMNS, table=self._table, where=where), + [query_embedding] + params + [query_embedding, limit], + ) + rows = cur.fetchall() + return [(_row_to_fact(row), float(row["score"])) for row in rows] + + def list_by_scope(self, scope: dict[str, str], limit: int = 100) -> list[Fact]: + where, params = _scope_where_clause(scope) + with self._conn.cursor(row_factory=dict_row) as cur: + cur.execute( + SQL("SELECT {cols} FROM {table} WHERE {where} " "ORDER BY created_at DESC LIMIT %s").format( + cols=_FACT_COLUMNS, table=self._table, where=where + ), + params + [limit], + ) + rows = cur.fetchall() + return [_row_to_fact(row) for row in rows] + + def delete_scope(self, scope: dict[str, str]) -> int: + where, params = _scope_where_clause(scope) + with self._conn.cursor() as cur: + cur.execute( + SQL("DELETE FROM {table} WHERE {where}").format( + table=self._table, + where=where, + ), + params, + ) + return cur.rowcount diff --git a/dynamiq/memory/long_term/backends/pinecone.py b/dynamiq/memory/long_term/backends/pinecone.py new file mode 100644 index 000000000..01ddc1a7e --- /dev/null +++ b/dynamiq/memory/long_term/backends/pinecone.py @@ -0,0 +1,212 @@ +import json +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from pydantic import ConfigDict, Field, PrivateAttr + +from dynamiq.connections import Pinecone as PineconeConnection +from dynamiq.memory.long_term.base import LongTermMemoryBackend +from dynamiq.memory.long_term.schemas import Fact + +if TYPE_CHECKING: + from pinecone import Pinecone as PineconeClient + +# Pinecone metadata doesn't accept nested dicts; we JSON-encode the Fact's +# `metadata` into this single string field and decode on read. 
+_METADATA_JSON_KEY = "metadata_json" + + +def _scope_to_filter(scope: dict[str, str]) -> dict | None: + """Translate `{key: value, ...}` into a Pinecone metadata filter. + + Single-key scopes become a flat `{key: {"$eq": value}}`; multi-key scopes + are wrapped in `$and` since Pinecone treats sibling keys as implicit-AND + only when each is a leaf. + """ + if not scope: + return None + if len(scope) == 1: + (key, value), = scope.items() + return {key: {"$eq": value}} + return {"$and": [{key: {"$eq": value}} for key, value in scope.items()]} + + +def _fact_to_metadata(fact: Fact) -> dict[str, Any]: + return { + "fact_id": fact.id, + "content": fact.content, + "hash": fact.hash, + "user_id": fact.user_id, + _METADATA_JSON_KEY: json.dumps(fact.metadata or {}), + "created_at": fact.created_at.isoformat(), + "updated_at": fact.updated_at.isoformat(), + } + + +def _metadata_to_fact(meta: dict[str, Any]) -> Fact: + raw_meta = meta.get(_METADATA_JSON_KEY) or "{}" + return Fact( + id=meta["fact_id"], + content=meta["content"], + hash=meta["hash"], + user_id=meta["user_id"], + metadata=json.loads(raw_meta) if isinstance(raw_meta, str) else raw_meta, + created_at=datetime.fromisoformat(meta["created_at"]), + updated_at=datetime.fromisoformat(meta["updated_at"]), + ) + + +class PineconeLongTermMemoryBackend(LongTermMemoryBackend): + """Long-term memory backend backed by Pinecone. + + Facts are stored as Pinecone vectors keyed by the original `fact_id` (Pinecone + accepts arbitrary string ids). Fact payload lives in the vector's metadata; the + free-form `Fact.metadata` dict is JSON-encoded into a single string field to + avoid Pinecone's no-nested-dicts restriction. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str = "pinecone-long-term-memory-backend" + connection: PineconeConnection = Field(default_factory=PineconeConnection) + index_name: str = "user_facts" + namespace: str = "default" + dimension: int = 1536 + # Scroll/list pagination cap. 
Pinecone's max top_k is 10000 per query. + _LIST_PAGE_SIZE: int = 10_000 + + _client: "PineconeClient | None" = PrivateAttr(default=None) + _index: Any = PrivateAttr(default=None) + + @property + def to_dict_exclude_params(self) -> dict[str, bool]: + return super().to_dict_exclude_params | {"_client": True, "_index": True, "connection": True} + + def to_dict(self, include_secure_params: bool = False, for_tracing: bool = False, **kwargs) -> dict[str, Any]: + data = super().to_dict(include_secure_params=include_secure_params, for_tracing=for_tracing, **kwargs) + data["connection"] = self.connection.to_dict( + for_tracing=for_tracing, include_secure_params=include_secure_params, **kwargs + ) + return data + + def model_post_init(self, __context) -> None: + self._client = self.connection.connect() + self._index = self._client.Index(name=self.index_name) + + def insert(self, fact: Fact, embedding: list[float]) -> None: + self._index.upsert( + vectors=[{"id": fact.id, "values": list(embedding), "metadata": _fact_to_metadata(fact)}], + namespace=self.namespace, + ) + + def get(self, fact_id: str) -> Fact | None: + result = self._index.fetch(ids=[fact_id], namespace=self.namespace) + vectors = result.get("vectors") if isinstance(result, dict) else getattr(result, "vectors", {}) + if not vectors or fact_id not in vectors: + return None + vec = vectors[fact_id] + meta = vec["metadata"] if isinstance(vec, dict) else vec.metadata + return _metadata_to_fact(meta) + + def get_by_hash(self, *, user_id: str, content_hash: str) -> Fact | None: + # Pinecone's metadata-only filter goes through `query`; we send a zero + # vector since the score is irrelevant for a hash lookup. 
+ result = self._index.query( + vector=[0.0] * self.dimension, + top_k=1, + namespace=self.namespace, + filter=_scope_to_filter({"user_id": user_id, "hash": content_hash}), + include_metadata=True, + ) + matches = result.get("matches") if isinstance(result, dict) else getattr(result, "matches", []) + if not matches: + return None + match = matches[0] + meta = match["metadata"] if isinstance(match, dict) else match.metadata + return _metadata_to_fact(meta) + + def delete(self, fact_id: str) -> None: + self._index.delete(ids=[fact_id], namespace=self.namespace) + + def update( + self, + fact_id: str, + *, + content: str, + content_hash: str, + embedding: list[float], + metadata: dict, + updated_at: datetime, + ) -> None: + existing = self.get(fact_id) + if existing is None: + return + new_fact = existing.model_copy( + update={"content": content, "hash": content_hash, "metadata": metadata, "updated_at": updated_at} + ) + # `upsert` overwrites both the vector and the metadata payload in one call. 
+ self._index.upsert( + vectors=[{"id": fact_id, "values": list(embedding), "metadata": _fact_to_metadata(new_fact)}], + namespace=self.namespace, + ) + + def search( + self, + *, + query_embedding: list[float], + scope: dict[str, str], + limit: int, + ) -> list[tuple[Fact, float]]: + result = self._index.query( + vector=list(query_embedding), + top_k=limit, + namespace=self.namespace, + filter=_scope_to_filter(scope), + include_metadata=True, + ) + matches = result.get("matches") if isinstance(result, dict) else getattr(result, "matches", []) + out: list[tuple[Fact, float]] = [] + for match in matches: + meta = match["metadata"] if isinstance(match, dict) else match.metadata + score = match["score"] if isinstance(match, dict) else match.score + out.append((_metadata_to_fact(meta), float(score))) + return out + + def list_by_scope(self, scope: dict[str, str], limit: int = 100) -> list[Fact]: + # Pinecone has no "scan" primitive — the documented pattern is a query + # with a zero vector + filter. Capped at top_k=10000 (Pinecone's max). + top_k = min(max(limit, 1), self._LIST_PAGE_SIZE) + result = self._index.query( + vector=[0.0] * self.dimension, + top_k=top_k, + namespace=self.namespace, + filter=_scope_to_filter(scope), + include_metadata=True, + ) + matches = result.get("matches") if isinstance(result, dict) else getattr(result, "matches", []) + return [ + _metadata_to_fact(match["metadata"] if isinstance(match, dict) else match.metadata) for match in matches + ] + + def delete_scope(self, scope: dict[str, str]) -> int: + # Pinecone Serverless does NOT support delete-by-filter — only delete-by-id. + # And Pinecone's query API has no cursor, so we loop: query → delete the + # matched ids → query again until the page comes back empty. Without the + # loop, scopes with >`_LIST_PAGE_SIZE` facts (10k) would silently leak. 
+ total = 0 + flt = _scope_to_filter(scope) + while True: + result = self._index.query( + vector=[0.0] * self.dimension, + top_k=self._LIST_PAGE_SIZE, + namespace=self.namespace, + filter=flt, + include_metadata=False, + ) + matches = result.get("matches") if isinstance(result, dict) else getattr(result, "matches", []) + ids = [match["id"] if isinstance(match, dict) else match.id for match in matches] + if not ids: + break + self._index.delete(ids=ids, namespace=self.namespace) + total += len(ids) + return total diff --git a/dynamiq/memory/long_term/backends/qdrant.py b/dynamiq/memory/long_term/backends/qdrant.py new file mode 100644 index 000000000..a8f5f06ce --- /dev/null +++ b/dynamiq/memory/long_term/backends/qdrant.py @@ -0,0 +1,246 @@ +import uuid +from datetime import datetime +from typing import Any + +from pydantic import ConfigDict, PrivateAttr +from qdrant_client import QdrantClient +from qdrant_client.http.models import ( + Distance, + FieldCondition, + Filter, + MatchValue, + PointIdsList, + PointStruct, + VectorParams, +) + +from dynamiq.connections import Qdrant as QdrantConnection +from dynamiq.memory.long_term.base import LongTermMemoryBackend +from dynamiq.memory.long_term.schemas import Fact + +_UUID_NAMESPACE = uuid.UUID("00000000-0000-0000-0000-000000000000") + + +def _to_point_id(fact_id: str) -> str: + """Map an arbitrary `fact_id` string to a deterministic Qdrant UUID. + + Qdrant requires UUID or unsigned-int point IDs; the original `fact_id` + is kept in the payload so lookups round-trip. 
+ """ + return uuid.uuid5(_UUID_NAMESPACE, fact_id).hex + + +def _scope_to_filter(scope: dict[str, str]) -> Filter | None: + if not scope: + return None + return Filter( + must=[ + FieldCondition(key=key, match=MatchValue(value=value)) + for key, value in scope.items() + ] + ) + + +def _fact_to_payload(fact: Fact) -> dict: + return { + "fact_id": fact.id, + "content": fact.content, + "hash": fact.hash, + "user_id": fact.user_id, + "metadata": fact.metadata, + "created_at": fact.created_at.isoformat(), + "updated_at": fact.updated_at.isoformat(), + } + + +def _payload_to_fact(payload: dict) -> Fact: + return Fact( + id=payload["fact_id"], + content=payload["content"], + hash=payload["hash"], + user_id=payload["user_id"], + metadata=payload.get("metadata") or {}, # a stored null must also fall back to {} + created_at=datetime.fromisoformat(payload["created_at"]), + updated_at=datetime.fromisoformat(payload["updated_at"]), + ) + + +class QdrantLongTermMemoryBackend(LongTermMemoryBackend): + """Long-term memory backend backed by Qdrant.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str = "qdrant-long-term-memory-backend" + connection: QdrantConnection + collection_name: str = "user_facts" + dimension: int = 1536 + + _client: QdrantClient | None = PrivateAttr(default=None) + + @property + def to_dict_exclude_params(self) -> dict[str, bool]: + return super().to_dict_exclude_params | {"_client": True, "connection": True} + + def to_dict(self, include_secure_params: bool = False, for_tracing: bool = False, **kwargs) -> dict[str, Any]: + # super() re-adds the embedder; we add the connection on top. 
+ data = super().to_dict(include_secure_params=include_secure_params, for_tracing=for_tracing, **kwargs) + data["connection"] = self.connection.to_dict( + for_tracing=for_tracing, include_secure_params=include_secure_params, **kwargs + ) + return data + + def model_post_init(self, __context) -> None: + self._client = self.connection.connect() + + def ensure_collection(self) -> None: + """Create the facts collection and payload indexes if absent. Safe to call repeatedly.""" + if not self._client.collection_exists(self.collection_name): + self._client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=self.dimension, distance=Distance.COSINE), + ) + for key in ("user_id", "hash"): + self._client.create_payload_index( + collection_name=self.collection_name, + field_name=key, + field_schema="keyword", + ) + + def recreate_collection(self) -> None: + """Drop and re-create the facts collection. Test-only helper.""" + if self._client.collection_exists(self.collection_name): + self._client.delete_collection(self.collection_name) + self.ensure_collection() + + def drop_collection(self) -> None: + """Drop the facts collection if it exists. 
Test-only helper.""" + if self._client.collection_exists(self.collection_name): + self._client.delete_collection(self.collection_name) + + def insert(self, fact: Fact, embedding: list[float]) -> None: + self._client.upsert( + collection_name=self.collection_name, + points=[ + PointStruct( + id=_to_point_id(fact.id), + vector=list(embedding), + payload=_fact_to_payload(fact), + ) + ], + ) + + def get(self, fact_id: str) -> Fact | None: + results = self._client.retrieve( + collection_name=self.collection_name, + ids=[_to_point_id(fact_id)], + with_payload=True, + with_vectors=False, + ) + if not results: + return None + return _payload_to_fact(results[0].payload) + + def get_by_hash(self, *, user_id: str, content_hash: str) -> Fact | None: + points, _ = self._client.scroll( + collection_name=self.collection_name, + scroll_filter=Filter( + must=[ + FieldCondition(key="user_id", match=MatchValue(value=user_id)), + FieldCondition(key="hash", match=MatchValue(value=content_hash)), + ] + ), + limit=1, + with_payload=True, + ) + if not points: + return None + return _payload_to_fact(points[0].payload) + + def delete(self, fact_id: str) -> None: + self._client.delete( + collection_name=self.collection_name, + points_selector=PointIdsList(points=[_to_point_id(fact_id)]), + ) + + def update( + self, + fact_id: str, + *, + content: str, + content_hash: str, + embedding: list[float], + metadata: dict, + updated_at: datetime, + ) -> None: + existing = self.get(fact_id) + if existing is None: + return + new_fact = existing.model_copy( + update={"content": content, "hash": content_hash, "metadata": metadata, "updated_at": updated_at} + ) + self._client.upsert( + collection_name=self.collection_name, + points=[ + PointStruct( + id=_to_point_id(fact_id), + vector=list(embedding), + payload=_fact_to_payload(new_fact), + ) + ], + ) + + def search( + self, + *, + query_embedding: list[float], + scope: dict[str, str], + limit: int, + ) -> list[tuple[Fact, float]]: + results = 
self._client.search( + collection_name=self.collection_name, + query_vector=list(query_embedding), + query_filter=_scope_to_filter(scope), + limit=limit, + with_payload=True, + with_vectors=False, + ) + return [(_payload_to_fact(point.payload), float(point.score)) for point in results] + + def list_by_scope(self, scope: dict[str, str], limit: int = 100) -> list[Fact]: + points, _ = self._client.scroll( + collection_name=self.collection_name, + scroll_filter=_scope_to_filter(scope), + limit=limit, + with_payload=True, + ) + return [_payload_to_fact(p.payload) for p in points] + + def delete_scope(self, scope: dict[str, str]) -> int: + # Scroll for all matching point ids (paginated, no 10k cap), then delete + # them in one call. Compared to count-then-delete-by-filter this trades + # an extra round-trip for an accurate count of what we actually removed + # — the count+delete variant could diverge under concurrent writes. + # Empty scope = "match everything" — same contract as the in-memory + # and pgvector backends — so we use an empty Filter() rather than None. 
+ scope_filter = _scope_to_filter(scope) or Filter() + ids: list = [] + offset = None + while True: + points, offset = self._client.scroll( + collection_name=self.collection_name, + scroll_filter=scope_filter, + limit=1000, + offset=offset, + with_payload=False, + with_vectors=False, + ) + ids.extend(p.id for p in points) + if offset is None: + break + if not ids: + return 0 + self._client.delete( + collection_name=self.collection_name, + points_selector=PointIdsList(points=ids), + ) + return len(ids) diff --git a/dynamiq/memory/long_term/backends/weaviate.py b/dynamiq/memory/long_term/backends/weaviate.py new file mode 100644 index 000000000..50383ae07 --- /dev/null +++ b/dynamiq/memory/long_term/backends/weaviate.py @@ -0,0 +1,255 @@ +import json +import uuid +from datetime import datetime +from typing import TYPE_CHECKING, Any + +from pydantic import ConfigDict, Field, PrivateAttr + +from dynamiq.connections import Weaviate as WeaviateConnection +from dynamiq.memory.long_term.base import LongTermMemoryBackend +from dynamiq.memory.long_term.schemas import Fact + +if TYPE_CHECKING: + from weaviate import WeaviateClient + +# Weaviate properties are strictly typed and rejects nested objects, so we +# JSON-encode the Fact's `metadata` dict into a single TEXT property. +_METADATA_JSON_KEY = "metadata_json" +# Deterministic namespace so two backends pointing at the same collection +# resolve a given `fact_id` to the same UUID — required for delete/update +# round-trips when the original fact_id is not itself a UUID. 
+_UUID_NAMESPACE = uuid.UUID("00000000-0000-0000-0000-000000000000") + + +def _to_weaviate_uuid(fact_id: str) -> str: + return str(uuid.uuid5(_UUID_NAMESPACE, fact_id)) + + +def _fact_to_properties(fact: Fact) -> dict[str, Any]: + return { + "fact_id": fact.id, + "content": fact.content, + "hash": fact.hash, + "user_id": fact.user_id, + _METADATA_JSON_KEY: json.dumps(fact.metadata or {}), + "created_at": fact.created_at.isoformat(), + "updated_at": fact.updated_at.isoformat(), + } + + +def _properties_to_fact(props: dict[str, Any]) -> Fact: + raw_meta = props.get(_METADATA_JSON_KEY) or "{}" + return Fact( + id=props["fact_id"], + content=props["content"], + hash=props["hash"], + user_id=props["user_id"], + metadata=json.loads(raw_meta) if isinstance(raw_meta, str) else raw_meta, + created_at=datetime.fromisoformat(props["created_at"]), + updated_at=datetime.fromisoformat(props["updated_at"]), + ) + + +def _scope_to_filter(scope: dict[str, str]): + """Translate `{key: value, ...}` to a weaviate v4 `Filter` expression, AND-ing + multiple keys. Imported lazily so the module load doesn't require weaviate.""" + if not scope: + return None + from weaviate.classes.query import Filter + + items = list(scope.items()) + expr = Filter.by_property(items[0][0]).equal(items[0][1]) + for key, value in items[1:]: + expr = expr & Filter.by_property(key).equal(value) + return expr + + +def _id_in_filter(uuids: list[str]): + """`Filter.by_id().contains_any(...)` factored out so tests can stub it.""" + from weaviate.classes.query import Filter + + return Filter.by_id().contains_any(uuids) + + +class WeaviateLongTermMemoryBackend(LongTermMemoryBackend): + """Long-term memory backend backed by Weaviate (client v4). + + Each fact is one Weaviate object whose UUID is derived deterministically + from the original `fact_id` (UUID5 over a fixed namespace) so id-based + operations round-trip cleanly. 
Free-form `Fact.metadata` is JSON-encoded + into a single TEXT property to dodge Weaviate's strict-schema constraint. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str = "weaviate-long-term-memory-backend" + connection: WeaviateConnection = Field(default_factory=WeaviateConnection) + collection_name: str = "UserFacts" + dimension: int = 1536 + # Page size for scoped scans (list/delete). Capped at Weaviate's default + # `QUERY_MAXIMUM_RESULTS` so a single fetch never exceeds server limits. + _SCOPE_PAGE_SIZE: int = 10_000 + + _client: "WeaviateClient | None" = PrivateAttr(default=None) + + @property + def to_dict_exclude_params(self) -> dict[str, bool]: + return super().to_dict_exclude_params | {"_client": True, "connection": True} + + def to_dict(self, include_secure_params: bool = False, for_tracing: bool = False, **kwargs) -> dict[str, Any]: + data = super().to_dict(include_secure_params=include_secure_params, for_tracing=for_tracing, **kwargs) + data["connection"] = self.connection.to_dict( + for_tracing=for_tracing, include_secure_params=include_secure_params, **kwargs + ) + return data + + def model_post_init(self, __context) -> None: + # Only resolve the client here; the collection proxy is fetched lazily + # so backend construction does not depend on the collection already + # existing — callers can construct, call ensure_collection(), then use. + self._client = self.connection.connect() + + @property + def _collection(self): + """Lazy collection proxy. Re-fetched per access — the call is local to + the weaviate client (no network) and avoids stale state if the + collection is dropped/recreated between operations.""" + return self._client.collections.get(self.collection_name) + + def ensure_collection(self) -> None: + """Create the facts collection if absent. 
Safe to call repeatedly.""" + from weaviate.classes.config import Configure, DataType, Property + + if self._client.collections.exists(self.collection_name): + return + self._client.collections.create( + name=self.collection_name, + vectorizer_config=Configure.Vectorizer.none(), + vector_index_config=Configure.VectorIndex.hnsw( + distance_metric=Configure.VectorDistances.COSINE, + ), + properties=[ + Property(name="fact_id", data_type=DataType.TEXT), + Property(name="content", data_type=DataType.TEXT), + Property(name="hash", data_type=DataType.TEXT), + Property(name="user_id", data_type=DataType.TEXT), + Property(name=_METADATA_JSON_KEY, data_type=DataType.TEXT), + Property(name="created_at", data_type=DataType.TEXT), + Property(name="updated_at", data_type=DataType.TEXT), + ], + ) + + def recreate_collection(self) -> None: + """Drop and re-create the facts collection. Test-only helper.""" + if self._client.collections.exists(self.collection_name): + self._client.collections.delete(self.collection_name) + self.ensure_collection() + + def drop_collection(self) -> None: + """Drop the facts collection if it exists. 
Test-only helper.""" + if self._client.collections.exists(self.collection_name): + self._client.collections.delete(self.collection_name) + + def insert(self, fact: Fact, embedding: list[float]) -> None: + self._collection.data.insert( + uuid=_to_weaviate_uuid(fact.id), + properties=_fact_to_properties(fact), + vector=list(embedding), + ) + + def get(self, fact_id: str) -> Fact | None: + obj = self._collection.query.fetch_object_by_id(uuid=_to_weaviate_uuid(fact_id)) + if obj is None: + return None + return _properties_to_fact(obj.properties) + + def get_by_hash(self, *, user_id: str, content_hash: str) -> Fact | None: + result = self._collection.query.fetch_objects( + filters=_scope_to_filter({"user_id": user_id, "hash": content_hash}), + limit=1, + ) + objects = getattr(result, "objects", []) or [] + if not objects: + return None + return _properties_to_fact(objects[0].properties) + + def delete(self, fact_id: str) -> None: + self._collection.data.delete_by_id(uuid=_to_weaviate_uuid(fact_id)) + + def update( + self, + fact_id: str, + *, + content: str, + content_hash: str, + embedding: list[float], + metadata: dict, + updated_at: datetime, + ) -> None: + existing = self.get(fact_id) + if existing is None: + return + new_fact = existing.model_copy( + update={"content": content, "hash": content_hash, "metadata": metadata, "updated_at": updated_at} + ) + # `replace` overwrites properties + vector while preserving the uuid. 
+ self._collection.data.replace( + uuid=_to_weaviate_uuid(fact_id), + properties=_fact_to_properties(new_fact), + vector=list(embedding), + ) + + def search( + self, + *, + query_embedding: list[float], + scope: dict[str, str], + limit: int, + ) -> list[tuple[Fact, float]]: + from weaviate.classes.query import MetadataQuery + + result = self._collection.query.near_vector( + near_vector=list(query_embedding), + limit=limit, + filters=_scope_to_filter(scope), + return_metadata=MetadataQuery(distance=True), + ) + objects = getattr(result, "objects", []) or [] + out: list[tuple[Fact, float]] = [] + for obj in objects: + distance = getattr(obj.metadata, "distance", None) if obj.metadata is not None else None + # Weaviate returns cosine *distance* in [0, 2]; convert to similarity + # so callers see the same shape as Qdrant/Pinecone (`1.0 = best`). + score = 1.0 - float(distance) if distance is not None else 0.0 + out.append((_properties_to_fact(obj.properties), score)) + return out + + def list_by_scope(self, scope: dict[str, str], limit: int = 100) -> list[Fact]: + result = self._collection.query.fetch_objects( + filters=_scope_to_filter(scope), + limit=limit, + ) + objects = getattr(result, "objects", []) or [] + return [_properties_to_fact(obj.properties) for obj in objects] + + def delete_scope(self, scope: dict[str, str]) -> int: + # Weaviate's `delete_many(where=...)` is server-capped (default ~10k per + # call) and doesn't report a count we can rely on across versions. Loop + # fetch-then-delete-by-uuid batches until the page comes back empty — + # this is unbounded and gives an accurate count of what we actually + # removed. Empty scope = match everything, same contract as Qdrant / + # in-memory; we drive that with `fetch_objects(limit=...)` (no filter). 
+ flt = _scope_to_filter(scope) + total = 0 + while True: + if flt is None: + result = self._collection.query.fetch_objects(limit=self._SCOPE_PAGE_SIZE) + else: + result = self._collection.query.fetch_objects(filters=flt, limit=self._SCOPE_PAGE_SIZE) + objects = getattr(result, "objects", []) or [] + if not objects: + break + uuids = [str(o.uuid) for o in objects] + self._collection.data.delete_many(where=_id_in_filter(uuids)) + total += len(uuids) + return total diff --git a/dynamiq/memory/long_term/base.py b/dynamiq/memory/long_term/base.py new file mode 100644 index 000000000..9979e7b13 --- /dev/null +++ b/dynamiq/memory/long_term/base.py @@ -0,0 +1,266 @@ +from abc import ABC, abstractmethod +from datetime import UTC, datetime +from functools import cached_property +from hashlib import md5 +from typing import Any +from uuid import uuid4 + +from pydantic import BaseModel, ConfigDict, Field, computed_field + +from dynamiq.memory.long_term.schemas import Fact +from dynamiq.memory.long_term.types import ForgetStatus, RememberOutcome +from dynamiq.nodes.embedders.base import TextEmbedder, TextEmbedderInputSchema +from dynamiq.utils import generate_uuid +from dynamiq.utils.logger import logger + + +class LongTermMemoryError(Exception): + """Base exception for long-term memory operations.""" + + pass + + +def _content_hash(user_id: str, content: str) -> str: + """Per-user stable hash used only as a dedup key, never as a security primitive.""" + normalised = content.strip().lower() + return md5(f"{user_id}:{normalised}".encode(), usedforsecurity=False).hexdigest() + + +class LongTermMemoryBackend(ABC, BaseModel): + """Fact-shaped, user-scoped storage + embedding engine for long-term memory. + + Subclasses implement the abstract storage primitives (`insert`, `get`, + `search`, `update`, ...). 
The high-level operations the agent tools call — + `remember`, `recall`, `forget`, `list_all`, `clear_user` — are concrete here + (Template Method): they orchestrate embedding, dedup, and semantic upsert in + terms of those primitives, so every backend gets them for free. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + name: str = "long-term-memory-backend" + id: str = Field(default_factory=generate_uuid) + embedder: TextEmbedder = Field( + ..., + description="Text embedder used to vectorize facts on write and queries on read.", + ) + upsert_threshold: float = Field( + default=0.85, + ge=0.0, + le=1.0, + description=( + "Cosine similarity above which a new `remember()` call replaces the " + "nearest existing fact in place instead of inserting a new row. " + "Set to 1.0 to disable upsert (insert-only)." + ), + ) + + @computed_field + @cached_property + def type(self) -> str: + """Fully-qualified class id used by the YAML loader for polymorphic reconstruction.""" + return f"{self.__module__.rsplit('.', 1)[0]}.{self.__class__.__name__}" + + @property + def to_dict_exclude_params(self) -> dict[str, bool]: + """Field names to exclude from serialization (overridden by subclasses).""" + return {"embedder": True} + + def to_dict(self, include_secure_params: bool = False, for_tracing: bool = False, **kwargs) -> dict[str, Any]: + """Serialize the backend to a dict for workflow YAML round-trip.""" + data = self.model_dump(exclude=kwargs.pop("exclude", self.to_dict_exclude_params), **kwargs) + data["embedder"] = self.embedder.to_dict( + include_secure_params=include_secure_params, for_tracing=for_tracing, **kwargs + ) + return data + + def _embed(self, text: str) -> list[float]: + result = self.embedder.execute(input_data=TextEmbedderInputSchema(query=text)) + return list(result["embedding"]) + + # --- high-level operations (Template Method over the storage primitives) --- + + def remember( + self, *, content: str, user_id: str, metadata: dict[str, Any] | 
None = None + ) -> tuple[Fact, RememberOutcome]: + """Add or upsert a fact for `user_id`. Returns the fact and a `RememberOutcome`. + + 1. Exact-duplicate guard: if `(user_id, normalised content)` already exists, + return it with `UNCHANGED` (no embed cost). + 2. Otherwise embed once and search the user's facts for the nearest neighbour. + If the top match's cosine score exceeds `upsert_threshold`, replace that + fact's content/hash/embedding/metadata in place (preserving id, created_at) + and return `UPDATED`. This is how an agent "corrects" a fact: re-state it. + 3. Otherwise insert a brand-new fact and return `CREATED`. + + Raises: + LongTermMemoryError: If content is empty or storage fails. + """ + if not content or not content.strip(): + raise LongTermMemoryError("Fact content cannot be empty") + try: + normalised = content.strip() + content_hash = _content_hash(user_id, normalised) + + existing = self.get_by_hash(user_id=user_id, content_hash=content_hash) + if existing is not None: + logger.debug(f"LongTermMemory: exact-dedup hit for user={user_id}, fact {existing.id}") + return existing, RememberOutcome.UNCHANGED + + embedding = self._embed(normalised) + + nearest = self.search(query_embedding=embedding, scope={"user_id": user_id}, limit=1) + if nearest and nearest[0][1] >= self.upsert_threshold: + old_fact, score = nearest[0] + now = datetime.now(UTC) + # New metadata replaces the old when the caller supplies it; + # otherwise the existing metadata is preserved. 
+ new_metadata = metadata if metadata is not None else old_fact.metadata + self.update( + old_fact.id, + content=normalised, + content_hash=content_hash, + embedding=embedding, + metadata=new_metadata, + updated_at=now, + ) + logger.debug( + f"LongTermMemory: upsert hit (score={score:.3f}) — updated fact {old_fact.id} for user={user_id}" + ) + updated = old_fact.model_copy( + update={ + "content": normalised, + "hash": content_hash, + "metadata": new_metadata, + "updated_at": now, + } + ) + return updated, RememberOutcome.UPDATED + + now = datetime.now(UTC) + fact = Fact( + id=str(uuid4()), + content=normalised, + hash=content_hash, + user_id=user_id, + metadata=metadata or {}, + created_at=now, + updated_at=now, + ) + self.insert(fact, embedding) + logger.debug(f"LongTermMemory: stored fact {fact.id} for user={user_id}") + return fact, RememberOutcome.CREATED + except Exception as e: + logger.error(f"LongTermMemory.remember failed for user={user_id}: {e}") + raise LongTermMemoryError(f"Failed to remember fact: {e}") from e + + def recall(self, *, query: str, user_id: str, limit: int = 5) -> list[tuple[Fact, float]]: + """Semantic search for facts relevant to `query`, scoped to `user_id`. + + Raises: + LongTermMemoryError: If the query is empty or search fails. + """ + stripped = query.strip() if query else "" + if not stripped: + raise LongTermMemoryError("Recall query cannot be empty") + try: + embedding = self._embed(stripped) + results = self.search(query_embedding=embedding, scope={"user_id": user_id}, limit=limit) + logger.debug(f"LongTermMemory: recall for user={user_id} returned {len(results)} facts") + return results + except Exception as e: + logger.error(f"LongTermMemory.recall failed for user={user_id}: {e}") + raise LongTermMemoryError(f"Failed to recall facts: {e}") from e + + def forget(self, *, fact_id: str, user_id: str) -> ForgetStatus: + """Delete a fact by id, returning a `ForgetStatus`. Never raises on user mismatch. 
+ + Raises: + LongTermMemoryError: If the storage delete fails for any other reason. + """ + try: + fact = self.get(fact_id) + if fact is None: + return ForgetStatus.NOT_FOUND + if fact.user_id != user_id: + logger.warning( + f"LongTermMemory.forget: cross-user delete blocked " + f"(owner={fact.user_id}, caller={user_id}, fact={fact_id})" + ) + return ForgetStatus.FORBIDDEN + self.delete(fact_id) + logger.debug(f"LongTermMemory: deleted fact {fact_id} for user={user_id}") + return ForgetStatus.DELETED + except Exception as e: + logger.error(f"LongTermMemory.forget failed for fact={fact_id}, user={user_id}: {e}") + raise LongTermMemoryError(f"Failed to forget fact: {e}") from e + + def list_all(self, *, user_id: str, limit: int = 100) -> list[Fact]: + """Return up to `limit` facts for `user_id` (admin/introspection).""" + try: + return self.list_by_scope({"user_id": user_id}, limit=limit) + except Exception as e: + logger.error(f"LongTermMemory.list_all failed for user={user_id}: {e}") + raise LongTermMemoryError(f"Failed to list facts: {e}") from e + + def clear_user(self, *, user_id: str) -> int: + """Hard-delete every fact owned by `user_id` and return the count deleted.""" + try: + deleted = self.delete_scope({"user_id": user_id}) + logger.debug(f"LongTermMemory: cleared {deleted} facts for user={user_id}") + return deleted + except Exception as e: + logger.error(f"LongTermMemory.clear_user failed for user={user_id}: {e}") + raise LongTermMemoryError(f"Failed to clear user facts: {e}") from e + + # --- storage primitives (implemented per backend) --- + + @abstractmethod + def insert(self, fact: Fact, embedding: list[float]) -> None: + """Insert a new fact and its embedding. 
Caller has already deduped via `get_by_hash`.""" + + @abstractmethod + def get(self, fact_id: str) -> Fact | None: + """Fetch a fact by id, or `None` if it does not exist.""" + + @abstractmethod + def get_by_hash(self, *, user_id: str, content_hash: str) -> Fact | None: + """Fetch the fact matching `(user_id, content_hash)`, or `None`.""" + + @abstractmethod + def delete(self, fact_id: str) -> None: + """Hard-delete a single fact. No-op if not present.""" + + @abstractmethod + def search( + self, *, query_embedding: list[float], + scope: dict[str, str], limit: int, + ) -> list[tuple[Fact, float]]: + """Return up to `limit` `(fact, score)` tuples matching `scope`, most relevant first.""" + + @abstractmethod + def list_by_scope( + self, scope: dict[str, str], limit: int = 100, + ) -> list[Fact]: + """Return up to `limit` facts matching `scope`, non-semantically.""" + + @abstractmethod + def delete_scope(self, scope: dict[str, str]) -> int: + """Hard-delete every fact matching `scope` and return the count deleted.""" + + @abstractmethod + def update( + self, + fact_id: str, + *, + content: str, + content_hash: str, + embedding: list[float], + metadata: dict[str, Any], + updated_at: datetime, + ) -> None: + """Replace content/hash/embedding/metadata/updated_at for a fact in place. + + Preserves `id`, `user_id`, and `created_at`. Used by the semantic-upsert + path in `remember`. 
+ """ diff --git a/dynamiq/memory/long_term/long_term_memory.py b/dynamiq/memory/long_term/long_term_memory.py new file mode 100644 index 000000000..a0e956e7e --- /dev/null +++ b/dynamiq/memory/long_term/long_term_memory.py @@ -0,0 +1,54 @@ +from functools import cached_property +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, computed_field, field_serializer + +from dynamiq.memory.long_term.base import LongTermMemoryBackend +from dynamiq.memory.long_term.types import MemoryToolKind + + +class LongTermMemoryConfig(BaseModel): + """Agent-level configuration for long-term memory. + + Mirrors `SandboxConfig` / `SkillsConfig`: an on/off switch plus the backend + that does the work, plus which memory tools to expose to the LLM. All + operations (remember/recall/forget) live on `backend`. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + enabled: bool = True + backend: LongTermMemoryBackend = Field( + ..., + description="Backend engine that stores facts, embeds text, and serves remember/recall/forget.", + ) + tools: tuple[MemoryToolKind, ...] 
= Field( + default=(MemoryToolKind.REMEMBER, MemoryToolKind.RECALL), + description="Which long-term-memory tools to expose to the agent's LLM.", + ) + + @computed_field + @cached_property + def type(self) -> str: + """Fully-qualified class id used by the YAML loader for reconstruction.""" + return f"{self.__module__.rsplit('.', 1)[0]}.{self.__class__.__name__}" + + @property + def to_dict_exclude_params(self) -> dict[str, bool]: + """Fields excluded from default model_dump; re-added by `to_dict`.""" + return {"backend": True} + + def to_dict(self, include_secure_params: bool = False, **kwargs) -> dict[str, Any]: + """Serialize so the backend round-trips via its own `to_dict`.""" + for_tracing = kwargs.pop("for_tracing", False) + data = self.model_dump(exclude=kwargs.pop("exclude", self.to_dict_exclude_params), **kwargs) + data["backend"] = self.backend.to_dict( + include_secure_params=include_secure_params, for_tracing=for_tracing, **kwargs + ) + return data + + @field_serializer("tools") + def _serialize_tools(self, tools: tuple[MemoryToolKind, ...]) -> tuple[str, ...]: + # Emit plain string values so YAML round-trip and tracing work; pydantic + # default-mode dump returns enum members which yaml.safe_dump cannot render. 
+ return tuple(t.value for t in tools) diff --git a/dynamiq/memory/long_term/schemas.py b/dynamiq/memory/long_term/schemas.py new file mode 100644 index 000000000..6b239f2f8 --- /dev/null +++ b/dynamiq/memory/long_term/schemas.py @@ -0,0 +1,16 @@ +from datetime import datetime +from typing import Any + +from pydantic import BaseModel, Field + + +class Fact(BaseModel): + """A single long-term memory fact, scoped to a user.""" + + id: str + content: str + hash: str + user_id: str + metadata: dict[str, Any] = Field(default_factory=dict) + created_at: datetime + updated_at: datetime diff --git a/dynamiq/memory/long_term/types.py b/dynamiq/memory/long_term/types.py new file mode 100644 index 000000000..95e1ae74b --- /dev/null +++ b/dynamiq/memory/long_term/types.py @@ -0,0 +1,24 @@ +from enum import Enum + + +class ForgetStatus(str, Enum): + """Outcome of `LongTermMemoryBackend.forget()` (programmatic API only).""" + + DELETED = "deleted" + NOT_FOUND = "not_found" + FORBIDDEN = "forbidden" + + +class RememberOutcome(str, Enum): + """Outcome of `LongTermMemoryBackend.remember()` — distinguishes insert from upsert.""" + + CREATED = "created" + UPDATED = "updated" + UNCHANGED = "unchanged" + + +class MemoryToolKind(str, Enum): + """Kinds of long-term-memory tools exposed to an agent.""" + + REMEMBER = "remember" + RECALL = "recall" diff --git a/dynamiq/nodes/agents/agent.py b/dynamiq/nodes/agents/agent.py index 50eef5530..f771fa1f0 100644 --- a/dynamiq/nodes/agents/agent.py +++ b/dynamiq/nodes/agents/agent.py @@ -8,6 +8,7 @@ from dynamiq.callbacks import AgentStreamingParserCallback, StreamingQueueCallbackHandler from dynamiq.executors.context import ContextAwareThreadPoolExecutor from dynamiq.nodes.agents.base import Agent as BaseAgent +from dynamiq.nodes.agents.base import _run_extra_tools from dynamiq.nodes.agents.components import parser, schema_generator from dynamiq.nodes.agents.components.history_manager import HistoryManagerMixin from 
dynamiq.nodes.agents.exceptions import ( @@ -972,11 +973,15 @@ def _setup_prompt_and_stop_sequences( input_message: The user's input message history_messages: Optional conversation history """ + # Pass overlay-aware tool variables so per-call LTM tools appear in the + # system prompt (relevant for XML/ReAct mode, where the model learns + # about tools from the prompt rather than function-calling schemas). system_message = Message( role=MessageRole.SYSTEM, content=self.generate_prompt( tools_name=self.tool_names, - input_formats=schema_generator.generate_input_formats(self.tools, self.sanitize_tool_name), + tool_description=self.tool_description, + input_formats=schema_generator.generate_input_formats(self._runtime_tools, self.sanitize_tool_name), ), static=True, ) @@ -1504,10 +1509,11 @@ def _run_react_llm_step(self, config: RunnableConfig | None, loop_num: int, **kw try: native_parallel = self.parallel_tool_calls_enabled and self.inference_mode == InferenceMode.FUNCTION_CALLING + fc_tools, response_format = self._effective_inference_schemas() llm_result = self._run_llm( messages=messages, - tools=self._tools, - response_format=self._response_format, + tools=fc_tools, + response_format=response_format, config=llm_config, parallel_tool_calls=True if native_parallel else None, **kwargs, @@ -1914,6 +1920,36 @@ def _refresh_agent_state(self, loop_num: int) -> None: except Exception as e: logger.debug("Failed to load todo state (none or invalid): %s", e) + def _build_inference_schemas(self, tools: list) -> tuple: + """Build (function_calling_tools, response_format) for the given tool list. + + Returns the init-time defaults for modes that don't apply, so callers + can substitute whichever value the current inference mode produces. 
+ """ + fc_tools = self._tools + response_format = self._response_format + if self.inference_mode == InferenceMode.FUNCTION_CALLING: + fc_tools = schema_generator.generate_function_calling_schemas( + tools, + self.delegation_allowed, + self.sanitize_tool_name, + response_format=self.response_format, + ) + elif self.inference_mode == InferenceMode.STRUCTURED_OUTPUT: + response_format = schema_generator.generate_structured_output_schemas( + tools, self.sanitize_tool_name, self.delegation_allowed + ) + return fc_tools, response_format + + def _effective_inference_schemas(self) -> tuple: + """Inference schemas for the current call, including any per-call LTM + overlay. When no overlay is set this is the init-time cache; when LTM + tools are attached they're regenerated so remember/recall are visible + to the LLM in FUNCTION_CALLING and STRUCTURED_OUTPUT modes.""" + if not _run_extra_tools.get(): + return self._tools, self._response_format + return self._build_inference_schemas(self._runtime_tools) + def _init_prompt_blocks(self): """Initialize the prompt blocks required for the ReAct strategy.""" # Generate inference-mode schemas @@ -1940,10 +1976,17 @@ def _init_prompt_blocks(self): response_format_schema=response_format_schema, ) + # `has_tools` decides whether the XML/ReAct template emits tool-related + # blocks. Long-term memory injects per-call tools that aren't visible at + # init time, so we must opt in here too — otherwise the template has no + # placeholder for them and they stay invisible to the LLM. 
+ ltm_enabled = self.long_term_memory is not None and self.long_term_memory.enabled self.system_prompt_manager.build_react_prompt( ReactPromptConfig( inference_mode=self.inference_mode, - has_tools=bool(self.tools) or (self.skills.enabled and self.skills.source is not None), + has_tools=bool(self.tools) + or (self.skills.enabled and self.skills.source is not None) + or ltm_enabled, parallel_tool_calls_enabled=self.parallel_tool_calls_enabled, delegation_allowed=self.delegation_allowed, context_compaction_enabled=self.summarization_config.enabled, diff --git a/dynamiq/nodes/agents/base.py b/dynamiq/nodes/agents/base.py index 8997a4980..2293d5d99 100644 --- a/dynamiq/nodes/agents/base.py +++ b/dynamiq/nodes/agents/base.py @@ -1,6 +1,7 @@ import io import json import re +from contextvars import ContextVar from copy import deepcopy from enum import Enum from typing import Any, Callable, ClassVar, Union @@ -10,6 +11,7 @@ from dynamiq.connections.managers import ConnectionManager from dynamiq.memory import Memory, MemoryRetrievalStrategy, MemorySaveMode +from dynamiq.memory.long_term import LongTermMemoryConfig from dynamiq.nodes import ErrorHandling, Node, NodeGroup from dynamiq.nodes.agents.checkpoint import DEFAULT_HISTORY_OFFSET, AgentIterativeCheckpointMixin from dynamiq.nodes.agents.exceptions import AgentUnknownToolException, InvalidActionException, ToolExecutionException @@ -47,6 +49,14 @@ from dynamiq.utils.logger import logger from dynamiq.utils.utils import deep_merge +# Per-call overlay of tools visible to a single `Agent.execute` invocation — +# currently used for long-term-memory tools that bind a request's `user_id` +# at construction. ContextVar gives per-thread and per-asyncio-task isolation +# without mutating shared agent state, so concurrent execute() calls on the +# same agent instance never see each other's user-scoped tools and never +# block on a lock. 
+_run_extra_tools: ContextVar[list["Node"]] = ContextVar("dynamiq_agent_run_extra_tools", default=[]) + class StreamChunkChoiceDelta(BaseModel): """Delta model for content chunks.""" @@ -222,6 +232,13 @@ class Agent(AgentIterativeCheckpointMixin, Node): memory: Memory | None = Field(None, description="Memory node for the agent.") memory_limit: int = Field(100, description="Maximum number of messages to retrieve from memory") memory_retrieval_strategy: MemoryRetrievalStrategy | None = MemoryRetrievalStrategy.ALL + long_term_memory: LongTermMemoryConfig | None = Field( + default=None, + description=( + "Long-term, fact-shaped, user-scoped memory config (enabled + backend + tools). " + "Accessed via remember/recall tools. Independent of `memory` (short-term messages)." + ), + ) verbose: bool = Field(False, description="Whether to print verbose logs.") file_store: FileStoreConfig = Field( default_factory=lambda: FileStoreConfig(enabled=False, backend=InMemoryFileStore()), @@ -384,6 +401,7 @@ def to_dict_exclude_params(self): "llm": True, "tools": True, "memory": True, + "long_term_memory": True, "files": True, "images": True, "file_store": True, @@ -402,6 +420,7 @@ def to_dict(self, **kwargs) -> dict: data["tools"] = data["tools"] + [mcp_server.to_dict(**kwargs) for mcp_server in self._mcp_servers] data["memory"] = self.memory.to_dict(**kwargs) if self.memory else None + data["long_term_memory"] = self.long_term_memory.to_dict(**kwargs) if self.long_term_memory else None if self.files: data["files"] = [{"name": getattr(f, "name", f"file_{i}")} for i, f in enumerate(self.files)] if self.images: @@ -430,6 +449,12 @@ def init_components(self, connection_manager: ConnectionManager | None = None): tool.init_components(connection_manager) tool.is_optimized_for_agents = True + # The LTM embedder is a ConnectionNode that needs its text_embedder + # client built before the first recall/remember call, otherwise it + # AttributeErrors on a `None` client during `execute`. 
+ if self.long_term_memory and self.long_term_memory.backend.embedder.is_postponed_component_init: + self.long_term_memory.backend.embedder.init_components(connection_manager) + self._ensure_skills_ingested_for_sandbox() def _ensure_skills_ingested_for_sandbox(self) -> None: @@ -618,123 +643,145 @@ def execute( use_memory = self.memory and (input_data.user_id or input_data.session_id) - if use_memory: - history_messages = self._retrieve_memory(input_data) - if len(history_messages) > 0: - history_messages.insert( - 0, - Message( - role=MessageRole.SYSTEM, - content="Below is the previous conversation history. " - "Use this context to inform your response.", - static=True, - ), - ) - else: - history_messages = None - - files = input_data.files - if files: - normalized_files = self._ensure_named_files(files) - file_paths = [] - if self.sandbox_backend: - file_paths = self._upload_files_to_sandbox(normalized_files) - else: - if not self.file_store_backend: - self._setup_in_memory_file_store_and_tools() - if self.file_store_backend: - file_paths = self._upload_files_to_file_store(normalized_files) - input_message = self._inject_attached_files_into_message( - input_message, normalized_files, file_paths=file_paths + ltm_tools = self._build_long_term_memory_tools(input_data) + if ltm_tools: + logger.info( + "Agent %s - %s: attached %d long-term memory tools (%s)", + self.name, + self.id, + len(ltm_tools), + ", ".join(t.name for t in ltm_tools), ) - - if input_data.tool_params: - kwargs["tool_params"] = input_data.tool_params - - self.system_prompt_manager.update_variables(dict(input_data)) - kwargs = kwargs | {"parent_run_id": kwargs.get("run_id")} - kwargs.pop("run_depends", None) - + # Publish the per-call LTM tools via the module-level ContextVar; the + # tool-resolution properties (`tool_description`, `tool_names`, + # `tool_by_names`) and inference-schema generation read it. 
Setting a + # ContextVar is cheap, isolated per thread / per asyncio task, and never + # mutates shared state — so concurrent execute() calls don't see each + # other's user-scoped tools and don't need a lock. The set() is the last + # statement before `try:` so nothing can raise between it and the + # matching reset() in finally. + ltm_token = _run_extra_tools.set(ltm_tools) if ltm_tools else None try: - result = self._run_agent(input_message, history_messages, config=config, **kwargs) - except CanceledException: if use_memory: - try: - self._save_history_to_memory(custom_metadata) - except Exception as save_error: - logger.error( - f"Agent {self.name} - {self.id}: failed to save history to memory " - f"after cancel: {save_error}", + history_messages = self._retrieve_memory(input_data) + if len(history_messages) > 0: + history_messages.insert( + 0, + Message( + role=MessageRole.SYSTEM, + content="Below is the previous conversation history. " + "Use this context to inform your response.", + static=True, + ), ) + else: + history_messages = None + + files = input_data.files + if files: + normalized_files = self._ensure_named_files(files) + file_paths = [] + if self.sandbox_backend: + file_paths = self._upload_files_to_sandbox(normalized_files) + else: + if not self.file_store_backend: + self._setup_in_memory_file_store_and_tools() + if self.file_store_backend: + file_paths = self._upload_files_to_file_store(normalized_files) + input_message = self._inject_attached_files_into_message( + input_message, normalized_files, file_paths=file_paths + ) + + if input_data.tool_params: + kwargs["tool_params"] = input_data.tool_params + + self.system_prompt_manager.update_variables(dict(input_data)) + kwargs = kwargs | {"parent_run_id": kwargs.get("run_id")} + kwargs.pop("run_depends", None) + + try: + result = self._run_agent(input_message, history_messages, config=config, **kwargs) + except CanceledException: + if use_memory: + try: + 
self._save_history_to_memory(custom_metadata) + except Exception as save_error: + logger.error( + f"Agent {self.name} - {self.id}: failed to save history to memory " + f"after cancel: {save_error}", + ) + try: + self._append_user_input_to_memory(custom_metadata) + except Exception as save_error2: + logger.error( + f"Agent {self.name} - {self.id}: also failed to save user input " + f"after cancel: {save_error2}", + ) + raise + except Exception: + if use_memory: try: self._append_user_input_to_memory(custom_metadata) - except Exception as save_error2: + except Exception as save_error: logger.error( - f"Agent {self.name} - {self.id}: also failed to save user input " - f"after cancel: {save_error2}", + f"Agent {self.name} - {self.id}: failed to save user input to memory " + f"after agent error: {save_error}", ) - raise - except Exception: + raise + finally: + self._current_call_context = None + self._clear_todos_file() + if use_memory: try: - self._append_user_input_to_memory(custom_metadata) + self._save_history_to_memory(custom_metadata, final_output=result) except Exception as save_error: logger.error( - f"Agent {self.name} - {self.id}: failed to save user input to memory " - f"after agent error: {save_error}", + "Agent %s - %s: failed to save history to memory: %s", + self.name, + self.id, + save_error, ) - raise - finally: - self._current_call_context = None - self._clear_todos_file() - if use_memory: - try: - self._save_history_to_memory(custom_metadata, final_output=result) - except Exception as save_error: - logger.error( - "Agent %s - %s: failed to save history to memory: %s", - self.name, - self.id, - save_error, - ) - - execution_result = { - "content": result, - } + execution_result = { + "content": result, + } - requested_paths = getattr(self, "_requested_output_files", None) + requested_paths = getattr(self, "_requested_output_files", None) - if self.file_store_backend and requested_paths: - try: - stored_files = 
self.file_store_backend.list_files_bytes(requested_paths) - except Exception as e: - logger.warning(f"Agent {self.name} - {self.id}: failed to collect files from file store: {e}") - stored_files = [] - if stored_files: - execution_result["files"] = stored_files - logger.info( - f"Agent {self.name} - {self.id}: " - f"returning {len(stored_files)} requested file(s) from file store" - ) + if self.file_store_backend and requested_paths: + try: + stored_files = self.file_store_backend.list_files_bytes(requested_paths) + except Exception as e: + logger.warning(f"Agent {self.name} - {self.id}: failed to collect files from file store: {e}") + stored_files = [] + if stored_files: + execution_result["files"] = stored_files + logger.info( + f"Agent {self.name} - {self.id}: " + f"returning {len(stored_files)} requested file(s) from file store" + ) - if self.sandbox_backend and requested_paths: - try: - sandbox_files = self.sandbox_backend.collect_files(file_paths=requested_paths) - except Exception as e: - logger.warning(f"Agent {self.name} - {self.id}: failed to collect files from sandbox: {e}") - sandbox_files = [] - if sandbox_files: - existing_files = execution_result.get("files", []) - execution_result["files"] = existing_files + sandbox_files - logger.info( - f"Agent {self.name} - {self.id}: " - f"returning {len(sandbox_files)} requested file(s) from sandbox" - ) + if self.sandbox_backend and requested_paths: + try: + sandbox_files = self.sandbox_backend.collect_files(file_paths=requested_paths) + except Exception as e: + logger.warning(f"Agent {self.name} - {self.id}: failed to collect files from sandbox: {e}") + sandbox_files = [] + if sandbox_files: + existing_files = execution_result.get("files", []) + execution_result["files"] = existing_files + sandbox_files + logger.info( + f"Agent {self.name} - {self.id}: " + f"returning {len(sandbox_files)} requested file(s) from sandbox" + ) - logger.info(f"Node {self.name} - {self.id}: finished with 
def _build_long_term_memory_tools(self, input_data: "AgentInputSchema") -> list[Node]:
    """Build the long-term-memory tools for a single run.

    Args:
        input_data: The agent input; only its optional ``user_id`` attribute is read.

    Returns:
        list[Node]: Tools bound to the request's ``user_id``. The list is empty when
        long-term memory is not configured, is disabled, or ``input_data`` carries no
        truthy ``user_id``.
    """
    ltm_config = self.long_term_memory
    if ltm_config is None or not ltm_config.enabled:
        return []

    bound_user_id = getattr(input_data, "user_id", None)
    if not bound_user_id:
        # Without a user there is nothing to scope the facts to.
        return []

    # Imported at call time rather than module import time
    # (presumably to avoid an import cycle -- TODO confirm).
    from dynamiq.nodes.tools.long_term_memory import build_long_term_memory_tools

    ltm_tools = build_long_term_memory_tools(
        backend=ltm_config.backend,
        user_id=bound_user_id,
        include=ltm_config.tools,
    )
    # `init_components` set this on every tool that existed at agent build
    # time; LTM tools are constructed lazily per-run and must match so the
    # remember/recall outputs render as friendly strings rather than raw dicts.
    for ltm_tool in ltm_tools:
        ltm_tool.is_optimized_for_agents = True
    return ltm_tools
@property
def tool_description(self) -> str:
    """Return one ``- name: description`` line per runtime tool, or "" when there are none."""
    runtime_tools = self._runtime_tools
    if not runtime_tools:
        return ""
    entries = [f"- {tool.name}: {tool.description.strip()}" for tool in runtime_tools]
    return "\n".join(entries)

@property
def tool_names(self) -> str:
    """Return the sanitized names of the runtime tools joined by commas."""
    sanitized = [self.sanitize_tool_name(tool.name) for tool in self._runtime_tools]
    return ",".join(sanitized)

@property
def tool_by_names(self) -> dict[str, Node]:
    """Return a mapping from sanitized tool name to the runtime tool Node.

    When two runtime tools sanitize to the same name, the later one wins.
    """
    lookup = {}
    for tool in self._runtime_tools:
        lookup[self.sanitize_tool_name(tool.name)] = tool
    return lookup
diff --git a/dynamiq/nodes/tools/long_term_memory.py b/dynamiq/nodes/tools/long_term_memory.py new file mode 100644 index 000000000..6b8cae14f --- /dev/null +++ b/dynamiq/nodes/tools/long_term_memory.py @@ -0,0 +1,220 @@ +from typing import Any, ClassVar, Literal + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from dynamiq.memory.long_term import LongTermMemoryBackend, MemoryToolKind, RememberOutcome +from dynamiq.nodes.node import Node, ensure_config +from dynamiq.nodes.types import NodeGroup +from dynamiq.runnables import RunnableConfig +from dynamiq.types.cancellation import check_cancellation +from dynamiq.utils.logger import logger + +REMEMBER_DESCRIPTION = """Persist a durable fact about the current user to long-term memory. + +Key capabilities: +- Survives across conversations and sessions — not just this chat +- Idempotent on identical content (re-stating the same fact is a no-op) +- Semantic upsert: re-stating a near-paraphrase REPLACES the older version in place, + which is how you correct or update what you previously remembered +- Optional structured metadata (e.g. category, source) for later filtering + +Usage strategy: +- Call when the user explicitly says "remember…", "save this…", "keep in mind for next time…" +- Call when you have learned something that will clearly matter in future sessions + (a stable preference, a constraint, a recurring context, biographical info) +- To CORRECT a previously-saved fact, just call this tool again with the corrected + statement — the older paraphrase is replaced automatically. +- Do NOT use for ephemeral turn-level state — that is what the conversation history is for +- Do NOT use to remember tool outputs, file paths, or anything tied to this run +- The fact is scoped to the current user automatically; never pass or invent a user id + +Returns: a short status line — "Fact saved.", "Fact updated.", or "Already remembered." 
+ +Examples: +- {"content": "Prefers dogs over cats"} +- {"content": "Allergic to peanuts", "metadata": {"category": "health"}} +- {"content": "Works in EST timezone", "metadata": {"category": "context"}} +""" + +RECALL_DESCRIPTION = """Search the user's long-term memory for facts relevant to one or more queries. + +Key capabilities: +- Semantic search (not keyword) — matches meaning, paraphrases, synonyms +- Scoped to the current user automatically — never crosses users +- Multi-query: pass several phrasings in one call; results are merged and de-duplicated. + Because matches are sensitive to phrasing, supplying 2–4 angles per recall typically + improves recall over a single query without an extra round-trip. +- Returns the most relevant facts first, as plain text + +Usage strategy: +- Call PROACTIVELY at the start of a turn when the request hints at something personal + (preferences, past decisions, biographical info, recurring context) +- Call BEFORE answering questions where prior context would change your response +- Prefer 2–4 distinct phrasings over a single query when the topic is fuzzy +- If no relevant facts are found, just proceed without them +- Skip when the question is purely factual or has no user-specific component + +Returns: a bullet list of relevant facts (most relevant first), or "No relevant facts." + +Examples: +- {"queries": ["food preferences"]} +- {"queries": ["what does the user do for work?", "user profession", "user job role"], "limit": 3} +- {"queries": ["timezone", "working hours", "schedule constraints"], "limit": 10} +""" + + +class RememberFactInputSchema(BaseModel): + """LLM-visible input for `remember_fact`. `user_id` is bound at construction.""" + + content: str = Field(..., min_length=1, max_length=1000, + description="The fact to remember, as a short statement.") + metadata: dict[str, Any] | None = Field( + default=None, + description="Optional free-form metadata (e.g. 
class RememberFactTool(_LongTermMemoryTool):
    """Tool that stores a fact in long-term memory for the bound ``user_id``.

    The user id comes from the tool instance, never from the LLM-supplied input.
    """

    name: str = "remember_fact"
    description: str = REMEMBER_DESCRIPTION
    input_schema: ClassVar[type[RememberFactInputSchema]] = RememberFactInputSchema

    def execute(
        self, input_data: RememberFactInputSchema, config: RunnableConfig | None = None, **kwargs
    ) -> dict[str, Any]:
        """Persist ``input_data.content`` through the backend and report the outcome.

        Args:
            input_data: Validated tool input carrying the fact text and optional metadata.
            config: Runnable configuration; normalised with ``ensure_config``.
            **kwargs: Forwarded to the node-execute callbacks.

        Returns:
            dict[str, Any]: ``{"content": <status line>}`` when the tool is optimized
            for agents, otherwise ``{"content": {"fact_id": ..., "outcome": ...}}``.
        """
        logger.debug(f"Tool {self.name} - {self.id}: started")
        run_config = ensure_config(config)
        check_cancellation(run_config)
        self.run_on_node_execute_run(run_config.callbacks, **kwargs)

        stored_fact, outcome = self.backend.remember(
            content=input_data.content,
            user_id=self.user_id,
            metadata=input_data.metadata,
        )

        if not self.is_optimized_for_agents:
            # Structured payload for programmatic callers.
            return {"content": {"fact_id": stored_fact.id, "outcome": outcome.value}}
        # Short human-readable status line for the LLM.
        return {"content": _OUTCOME_MESSAGES[outcome]}
def build_long_term_memory_tools(
    *,
    backend: LongTermMemoryBackend,
    user_id: str,
    include: tuple[MemoryToolKind | str, ...] = (
        MemoryToolKind.REMEMBER,
        MemoryToolKind.RECALL,
    ),
) -> list[Node]:
    """Construct long-term-memory tools with `user_id` baked in.

    Entries in `include` are processed in order. An entry is skipped when:
    - it is not a valid `MemoryToolKind` (ValueError on enum coercion);
    - it is a valid enum member without a corresponding builder (e.g. an enum
      value added but not yet wired into `_TOOL_BUILDERS`);
    - the same kind was already built earlier in `include` (e.g. both
      ``"remember"`` and ``MemoryToolKind.REMEMBER`` were passed).

    Args:
        backend: Backend every constructed tool reads from / writes to.
        user_id: User the tools are scoped to.
        include: Tool kinds to build, as enum members or their string values.

    Returns:
        list[Node]: At most one tool per distinct kind, in first-seen order.
    """
    tools: list[Node] = []
    built_kinds: set[MemoryToolKind] = set()
    for kind in include:
        try:
            tool_kind = MemoryToolKind(kind)
        except ValueError:
            continue
        if tool_kind in built_kinds:
            # A repeated kind would yield two tools sharing one name, which the
            # agent would advertise twice while resolving only one of them.
            continue
        builder = _TOOL_BUILDERS.get(tool_kind)
        if builder is None:
            continue
        built_kinds.add(tool_kind)
        tools.append(builder(backend=backend, user_id=user_id))
    return tools
def _connection_from_dsn(dsn: str) -> PostgreSQLConnection:
    """Build a ``PostgreSQLConnection`` from a ``postgresql://`` style DSN.

    Missing components fall back to ``localhost`` / ``5432`` / ``postgres`` /
    ``postgres`` / empty password.

    Args:
        dsn: Connection URI, e.g. ``postgresql://user:p%40ss@host:5432/dbname``.

    Returns:
        PostgreSQLConnection: Connection object populated from the DSN parts.
    """
    # Local import keeps this helper self-contained; the module only imports `urlparse`.
    from urllib.parse import unquote

    parsed = urlparse(dsn)
    # `urlparse` leaves percent-escapes in username/password/path untouched, so
    # credentials containing reserved characters (e.g. "@" as "%40") must be
    # decoded explicitly before being handed to the driver.
    # Strip the slash before applying the fallback so a bare "/" path still
    # resolves to the default database instead of an empty name.
    database = unquote(parsed.path.lstrip("/")) or "postgres"
    return PostgreSQLConnection(
        host=parsed.hostname or "localhost",
        port=parsed.port or 5432,
        database=database,
        user=unquote(parsed.username or "postgres"),
        password=unquote(parsed.password or ""),
    )
== "f1" + + +def test_pgvector_get_by_hash_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x", "h-shared"), fake_embedder.embed("x")) + assert backend.get_by_hash(user_id="u2", content_hash="h-shared") is None + + +def test_pgvector_metadata_round_trip(backend, fake_embedder): + fact = _fact("f1", "u1", "x") + fact = fact.model_copy(update={"metadata": {"category": "preference", "score": 0.8}}) + backend.insert(fact, fake_embedder.embed("x")) + fetched = backend.get("f1") + assert fetched.metadata == {"category": "preference", "score": 0.8} + + +# --- delete / list_by_scope / delete_scope --- + + +def test_pgvector_delete(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x"), fake_embedder.embed("x")) + backend.delete("f1") + assert backend.get("f1") is None + + +def test_pgvector_delete_unknown_is_noop(backend): + backend.delete("does-not-exist") + + +def test_pgvector_list_by_scope(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + listed = backend.list_by_scope({"user_id": "u1"}) + assert {f.id for f in listed} == {"f1", "f2"} + + +def test_pgvector_delete_scope(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + deleted = backend.delete_scope({"user_id": "u1"}) + assert deleted == 2 + assert backend.list_by_scope({"user_id": "u1"}) == [] + assert len(backend.list_by_scope({"user_id": "u2"})) == 1 + + +# --- search --- + + +def test_pgvector_search_relevance_ordered(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u1", "alpha-2"), fake_embedder.embed("alpha-2")) + backend.insert(_fact("f3", "u1", 
"zulu"), fake_embedder.embed("zulu")) + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=3, + ) + assert hits[0][0].id == "f1" + scores = [s for _, s in hits] + assert scores == sorted(scores, reverse=True) + + +def test_pgvector_search_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u2", "alpha"), fake_embedder.embed("alpha")) + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=5, + ) + assert [f.id for f, _ in hits] == ["f1"] + + +def test_pgvector_search_empty_returns_empty(backend, fake_embedder): + hits = backend.search( + query_embedding=fake_embedder.embed("x"), + scope={"user_id": "u1"}, + limit=5, + ) + assert hits == [] diff --git a/tests/integration_with_creds/memory/test_qdrant_fact_backend.py b/tests/integration_with_creds/memory/test_qdrant_fact_backend.py new file mode 100644 index 000000000..abe111846 --- /dev/null +++ b/tests/integration_with_creds/memory/test_qdrant_fact_backend.py @@ -0,0 +1,149 @@ +import os +from datetime import UTC, datetime + +import pytest + +QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333") + +try: # pragma: no cover - environment probe + import requests as _requests + + _requests.get(QDRANT_URL, timeout=1) + QDRANT_AVAILABLE = True +except Exception: + QDRANT_AVAILABLE = False + +pytestmark = pytest.mark.skipif( + not QDRANT_AVAILABLE, reason=f"Qdrant not reachable at {QDRANT_URL}" +) + + +from dynamiq.connections import Qdrant as QdrantConnection # noqa: E402 +from dynamiq.memory.long_term.backends.qdrant import QdrantLongTermMemoryBackend # noqa: E402 +from dynamiq.memory.long_term.schemas import Fact # noqa: E402 + + +@pytest.fixture +def backend(fake_embedder): + b = QdrantLongTermMemoryBackend( + connection=QdrantConnection(url=QDRANT_URL, api_key=""), + embedder=fake_embedder, + 
collection_name="test_user_facts", + dimension=16, + ) + b.recreate_collection() + yield b + b.drop_collection() + + +def _fact(fact_id, user_id, content, content_hash=None): + now = datetime.now(UTC) + return Fact( + id=fact_id, + content=content, + hash=content_hash or f"h-{fact_id}", + user_id=user_id, + metadata={}, + created_at=now, + updated_at=now, + ) + + +# --- insert / get / get_by_hash --- + + +def test_qdrant_insert_then_get(backend, fake_embedder): + fact = _fact("f1", "u1", "hello") + backend.insert(fact, fake_embedder.embed("hello")) + fetched = backend.get("f1") + assert fetched is not None + assert fetched.id == "f1" + assert fetched.content == "hello" + + +def test_qdrant_get_unknown_returns_none(backend): + assert backend.get("does-not-exist") is None + + +def test_qdrant_get_by_hash(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x", "h-shared"), fake_embedder.embed("x")) + found = backend.get_by_hash(user_id="u1", content_hash="h-shared") + assert found is not None and found.id == "f1" + + +def test_qdrant_get_by_hash_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x", "h-shared"), fake_embedder.embed("x")) + assert backend.get_by_hash(user_id="u2", content_hash="h-shared") is None + + +def test_qdrant_metadata_round_trip(backend, fake_embedder): + fact = _fact("f1", "u1", "x").model_copy( + update={"metadata": {"category": "preference", "score": 0.8}} + ) + backend.insert(fact, fake_embedder.embed("x")) + fetched = backend.get("f1") + assert fetched.metadata == {"category": "preference", "score": 0.8} + + +# --- delete / list_by_scope / delete_scope --- + + +def test_qdrant_delete(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x"), fake_embedder.embed("x")) + backend.delete("f1") + assert backend.get("f1") is None + + +def test_qdrant_list_by_scope(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), 
fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + listed = backend.list_by_scope({"user_id": "u1"}) + assert {f.id for f in listed} == {"f1", "f2"} + + +def test_qdrant_delete_scope(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + deleted = backend.delete_scope({"user_id": "u1"}) + assert deleted == 2 + assert backend.list_by_scope({"user_id": "u1"}) == [] + assert len(backend.list_by_scope({"user_id": "u2"})) == 1 + + +# --- search --- + + +def test_qdrant_search_relevance_ordered(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u1", "alpha-2"), fake_embedder.embed("alpha-2")) + backend.insert(_fact("f3", "u1", "zulu"), fake_embedder.embed("zulu")) + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=3, + ) + assert hits[0][0].id == "f1" + scores = [s for _, s in hits] + assert scores == sorted(scores, reverse=True) + + +def test_qdrant_search_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u2", "alpha"), fake_embedder.embed("alpha")) + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=5, + ) + assert [f.id for f, _ in hits] == ["f1"] + + +def test_qdrant_search_empty_returns_empty(backend, fake_embedder): + hits = backend.search( + query_embedding=fake_embedder.embed("x"), + scope={"user_id": "u1"}, + limit=5, + ) + assert hits == [] diff --git a/tests/unit/memory/long_term/__init__.py b/tests/unit/memory/long_term/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/memory/long_term/conftest.py 
class FakeTextEmbedder(TextEmbedder):
    """Deterministic `TextEmbedder` stand-in for tests.

    Each text is mapped to a `DIM`-length unit vector computed from its sha256
    digest, so identical texts always produce identical vectors and no real
    embedding component is involved.
    """

    name: str = "fake-text-embedder"
    connection: BaseConnection = _StubConnection()
    DIM: ClassVar[int] = 16

    def execute(self, input_data: TextEmbedderInputSchema, config=None, **kwargs) -> dict:
        """Return the query text together with its deterministic embedding."""
        # Accept both attribute-style and mapping-style inputs.
        if hasattr(input_data, "query"):
            text = input_data.query
        else:
            text = input_data["query"]
        return {"query": text, "embedding": self._embed(text)}

    def embed(self, text: str) -> list[float]:
        """Return the raw vector for `text` without building an input schema."""
        return self._embed(text)

    @classmethod
    def _embed(cls, text: str) -> list[float]:
        """Hash `text` with sha256 and scale the first `DIM` bytes to a unit vector."""
        digest = hashlib.sha256(text.encode("utf-8")).digest()
        # Map each byte from [0, 255] onto [-1.0, 1.0].
        components = []
        for byte in digest[: cls.DIM]:
            components.append((byte / 127.5) - 1.0)
        # Fall back to 1.0 so an all-zero vector does not divide by zero.
        magnitude = sum(value * value for value in components) ** 0.5 or 1.0
        return [value / magnitude for value in components]
InMemoryLongTermMemoryBackend + + return InMemoryLongTermMemoryBackend(embedder=fake_embedder) + + +@pytest.fixture +def user_id() -> str: + return "user-test-123" + + +@pytest.fixture +def other_user_id() -> str: + return "user-other-456" diff --git a/tests/unit/memory/long_term/test_base.py b/tests/unit/memory/long_term/test_base.py new file mode 100644 index 000000000..5240525f8 --- /dev/null +++ b/tests/unit/memory/long_term/test_base.py @@ -0,0 +1,24 @@ +import pytest + +from dynamiq.memory.long_term.base import LongTermMemoryBackend + + +def test_long_term_memory_backend_is_abstract(): + with pytest.raises(TypeError): + LongTermMemoryBackend() + + +def test_long_term_memory_backend_update_is_abstract(): + """Subclasses must implement `update` — semantic upsert depends on it.""" + + class MissingUpdate(LongTermMemoryBackend): + def insert(self, fact, embedding): ... + def get(self, fact_id): return None + def get_by_hash(self, *, user_id, content_hash): return None + def delete(self, fact_id): ... 
+ def search(self, *, query_embedding, scope, limit): return [] + def list_by_scope(self, scope, limit=100): return [] + def delete_scope(self, scope): return 0 + + with pytest.raises(TypeError, match="abstract"): + MissingUpdate() diff --git a/tests/unit/memory/long_term/test_in_memory_backend.py b/tests/unit/memory/long_term/test_in_memory_backend.py new file mode 100644 index 000000000..76eda682b --- /dev/null +++ b/tests/unit/memory/long_term/test_in_memory_backend.py @@ -0,0 +1,178 @@ +"""Tests for InMemoryLongTermMemoryBackend storage primitives.""" +from datetime import UTC, datetime, timedelta + +from dynamiq.memory.long_term.schemas import Fact + + +def _fact(fact_id: str, user_id: str, content: str, + content_hash: str | None = None) -> Fact: + now = datetime.now(UTC) + return Fact( + id=fact_id, content=content, + hash=content_hash or f"h-{fact_id}", + user_id=user_id, metadata={}, + created_at=now, updated_at=now, + ) + + +# --- insert / get / get_by_hash --- + +def test_insert_then_get(backend, fake_embedder): + fact = _fact("f1", "u1", "hello") + backend.insert(fact, fake_embedder.embed("hello")) + assert backend.get("f1") == fact + + +def test_get_unknown_returns_none(backend): + assert backend.get("does-not-exist") is None + + +def test_get_by_hash_returns_match(backend, fake_embedder): + fact = _fact("f1", "u1", "hello", content_hash="h-shared") + backend.insert(fact, fake_embedder.embed("hello")) + assert backend.get_by_hash(user_id="u1", content_hash="h-shared") == fact + + +def test_get_by_hash_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "hello", "h-shared"), fake_embedder.embed("hello")) + assert backend.get_by_hash(user_id="u2", content_hash="h-shared") is None + + +def test_get_by_hash_unknown_returns_none(backend): + assert backend.get_by_hash(user_id="u1", content_hash="nope") is None + + +# --- search --- + +def test_search_returns_relevance_ordered(backend, fake_embedder): + backend.insert(_fact("f1", "u1", 
"alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u1", "alpha-2"), fake_embedder.embed("alpha-2")) + backend.insert(_fact("f3", "u1", "zulu"), fake_embedder.embed("zulu")) + + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=3, + ) + assert hits[0][0].id == "f1" + scores = [score for _, score in hits] + assert scores == sorted(scores, reverse=True) + + +def test_search_filters_by_scope(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u2", "alpha"), fake_embedder.embed("alpha")) + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=10, + ) + assert [f.id for f, _ in hits] == ["f1"] + + +def test_search_respects_limit(backend, fake_embedder): + for i in range(5): + backend.insert(_fact(f"f{i}", "u1", f"text{i}"), fake_embedder.embed(f"text{i}")) + hits = backend.search( + query_embedding=fake_embedder.embed("text0"), + scope={"user_id": "u1"}, limit=2, + ) + assert len(hits) == 2 + + +def test_search_empty_store_returns_empty(backend, fake_embedder): + hits = backend.search( + query_embedding=fake_embedder.embed("anything"), + scope={"user_id": "u1"}, limit=5, + ) + assert hits == [] + + +# --- delete / list_by_scope / delete_scope --- + +def test_delete_removes_fact(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x"), fake_embedder.embed("x")) + backend.delete("f1") + assert backend.get("f1") is None + + +def test_delete_unknown_is_noop(backend): + backend.delete("does-not-exist") # must not raise + + +def test_list_by_scope_returns_in_scope_facts(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + listed = backend.list_by_scope({"user_id": "u1"}) + assert {f.id for f in 
listed} == {"f1", "f2"} + + +def test_list_by_scope_respects_limit(backend, fake_embedder): + for i in range(5): + backend.insert(_fact(f"f{i}", "u1", f"x{i}"), fake_embedder.embed(f"x{i}")) + assert len(backend.list_by_scope({"user_id": "u1"}, limit=2)) == 2 + + +# --- update --- + + +def test_update_replaces_content_hash_embedding_and_timestamp(backend, fake_embedder): + original = _fact("f1", "u1", "hello", content_hash="h-old") + backend.insert(original, fake_embedder.embed("hello")) + + new_time = original.updated_at + timedelta(seconds=5) + backend.update( + "f1", + content="hello world", + content_hash="h-new", + embedding=fake_embedder.embed("hello world"), + metadata={"category": "greeting"}, + updated_at=new_time, + ) + + updated = backend.get("f1") + assert updated.content == "hello world" + assert updated.hash == "h-new" + assert updated.metadata == {"category": "greeting"} + assert updated.updated_at == new_time + assert updated.id == original.id + assert updated.created_at == original.created_at + + hits = backend.search( + query_embedding=fake_embedder.embed("hello world"), + scope={"user_id": "u1"}, + limit=1, + ) + assert hits[0][0].content == "hello world" + + +def test_update_unknown_is_noop(backend, fake_embedder): + backend.update( + "does-not-exist", + content="x", + content_hash="h", + embedding=fake_embedder.embed("x"), + metadata={}, + updated_at=datetime.now(UTC), + ) # must not raise + assert backend.get("does-not-exist") is None + + +def test_delete_scope_removes_all_in_scope(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + deleted = backend.delete_scope({"user_id": "u1"}) + assert deleted == 2 + assert backend.list_by_scope({"user_id": "u1"}) == [] + assert len(backend.list_by_scope({"user_id": "u2"})) == 1 + + +def 
test_delete_scope_empty_scope_deletes_everything(backend, fake_embedder): + """Contract: empty scope = "match every fact" — same for all backends.""" + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u2", "b"), fake_embedder.embed("b")) + deleted = backend.delete_scope({}) + assert deleted == 2 + assert backend.list_by_scope({}) == [] diff --git a/tests/unit/memory/long_term/test_long_term_memory.py b/tests/unit/memory/long_term/test_long_term_memory.py new file mode 100644 index 000000000..436f86cf2 --- /dev/null +++ b/tests/unit/memory/long_term/test_long_term_memory.py @@ -0,0 +1,179 @@ +"""Tests for the long-term memory backend operations (remember/recall/forget/...).""" +import pytest + +from dynamiq.memory.long_term import LongTermMemoryError, RememberOutcome +from dynamiq.memory.long_term.backends.in_memory import InMemoryLongTermMemoryBackend + +# --- remember --- + + +def test_remember_returns_a_fact_and_persists_it(backend, user_id): + fact, outcome = backend.remember(content="User likes pizza", user_id=user_id) + assert outcome == RememberOutcome.CREATED + assert fact.id + assert fact.content == "User likes pizza" + assert fact.user_id == user_id + assert backend.get(fact.id) == fact + + +def test_remember_exact_duplicate_returns_unchanged(backend, user_id): + first, first_outcome = backend.remember(content="User likes pizza", user_id=user_id) + second, second_outcome = backend.remember(content="User likes pizza", user_id=user_id) + assert first_outcome == RememberOutcome.CREATED + assert second_outcome == RememberOutcome.UNCHANGED + assert first.id == second.id + + +def test_remember_does_not_dedup_across_users(backend, user_id, other_user_id): + a, _ = backend.remember(content="User likes pizza", user_id=user_id) + b, b_outcome = backend.remember(content="User likes pizza", user_id=other_user_id) + assert b_outcome == RememberOutcome.CREATED + assert a.id != b.id + assert a.user_id != b.user_id + + +def 
test_remember_normalises_whitespace_for_dedup(backend, user_id): + a, _ = backend.remember(content=" User likes pizza ", user_id=user_id) + b, b_outcome = backend.remember(content="USER LIKES PIZZA", user_id=user_id) + assert b_outcome == RememberOutcome.UNCHANGED + assert a.id == b.id + + +def test_remember_paraphrase_upserts_existing(fake_embedder, user_id): + """With a low threshold, a near-similar fact replaces the earlier one in place.""" + backend = InMemoryLongTermMemoryBackend(embedder=fake_embedder, upsert_threshold=0.0) + original, _ = backend.remember(content="User likes pizza", user_id=user_id) + updated, outcome = backend.remember(content="User loves pizza", user_id=user_id) + + assert outcome == RememberOutcome.UPDATED + assert updated.id == original.id + assert updated.content == "User loves pizza" + assert backend.get(original.id).content == "User loves pizza" + assert len(backend.list_all(user_id=user_id)) == 1 + + +def test_remember_distinct_content_inserts_new_when_threshold_high(backend, user_id): + """Default high threshold (0.85) keeps unrelated facts separate.""" + a, _ = backend.remember(content="User likes pizza", user_id=user_id) + b, outcome = backend.remember(content="User dislikes mushrooms", user_id=user_id) + assert outcome == RememberOutcome.CREATED + assert a.id != b.id + assert len(backend.list_all(user_id=user_id)) == 2 + + +def test_upsert_replaces_metadata_when_provided(fake_embedder, user_id): + """A corrected fact's new metadata must overwrite the old fact's metadata.""" + backend = InMemoryLongTermMemoryBackend(embedder=fake_embedder, upsert_threshold=0.0) + original, _ = backend.remember(content="User likes pizza", user_id=user_id, metadata={"category": "food"}) + updated, outcome = backend.remember( + content="User loves pizza", user_id=user_id, metadata={"category": "preference"} + ) + assert outcome == RememberOutcome.UPDATED + assert updated.id == original.id + assert updated.metadata == {"category": "preference"} + 
assert backend.get(original.id).metadata == {"category": "preference"} + + +def test_upsert_preserves_metadata_when_omitted(fake_embedder, user_id): + """When the corrected call passes no metadata, the old metadata is kept.""" + backend = InMemoryLongTermMemoryBackend(embedder=fake_embedder, upsert_threshold=0.0) + original, _ = backend.remember(content="User likes pizza", user_id=user_id, metadata={"category": "food"}) + updated, _ = backend.remember(content="User loves pizza", user_id=user_id) + assert updated.id == original.id + assert updated.metadata == {"category": "food"} + + +def test_remember_rejects_empty_content(backend, user_id): + with pytest.raises(LongTermMemoryError): + backend.remember(content=" ", user_id=user_id) + + +def test_remember_stores_metadata(backend, user_id): + fact, _ = backend.remember(content="x", user_id=user_id, metadata={"category": "preference"}) + assert backend.get(fact.id).metadata == {"category": "preference"} + + +# --- recall --- + + +def test_recall_returns_scored_facts(backend, user_id): + backend.remember(content="User likes pizza", user_id=user_id) + backend.remember(content="User dislikes mushrooms", user_id=user_id) + hits = backend.recall(query="pizza preferences", user_id=user_id, limit=2) + assert len(hits) == 2 + fact, score = hits[0] + assert fact.content + assert isinstance(score, float) + + +def test_recall_isolates_users(backend, user_id, other_user_id): + backend.remember(content="A's fact", user_id=user_id) + backend.remember(content="B's fact", user_id=other_user_id) + hits = backend.recall(query="fact", user_id=user_id, limit=5) + assert all(f.user_id == user_id for f, _ in hits) + + +def test_recall_respects_limit(backend, user_id): + for i in range(5): + backend.remember(content=f"fact-{i}", user_id=user_id) + hits = backend.recall(query="fact", user_id=user_id, limit=2) + assert len(hits) == 2 + + +def test_recall_empty_store_returns_empty(backend, user_id): + assert backend.recall(query="anything", 
user_id=user_id, limit=5) == [] + + +def test_recall_rejects_empty_query(backend, user_id): + with pytest.raises(LongTermMemoryError): + backend.recall(query=" ", user_id=user_id, limit=5) + + +# --- forget (programmatic API; not exposed to agents) --- + + +def test_forget_deletes_known_fact(backend, user_id): + fact, _ = backend.remember(content="x", user_id=user_id) + assert backend.forget(fact_id=fact.id, user_id=user_id) == "deleted" + assert backend.get(fact.id) is None + + +def test_forget_unknown_returns_not_found(backend, user_id): + assert backend.forget(fact_id="does-not-exist", user_id=user_id) == "not_found" + + +def test_forget_cross_user_returns_forbidden(backend, user_id, other_user_id): + fact, _ = backend.remember(content="x", user_id=user_id) + result = backend.forget(fact_id=fact.id, user_id=other_user_id) + assert result == "forbidden" + assert backend.get(fact.id) is not None + + +# --- admin / introspection --- + + +def test_list_all_returns_user_facts(backend, user_id, other_user_id): + backend.remember(content="a", user_id=user_id) + backend.remember(content="b", user_id=user_id) + backend.remember(content="c", user_id=other_user_id) + facts = backend.list_all(user_id=user_id) + assert {f.content for f in facts} == {"a", "b"} + + +def test_get_returns_fact_by_id(backend, user_id): + fact, _ = backend.remember(content="x", user_id=user_id) + assert backend.get(fact.id) == fact + + +def test_get_unknown_returns_none(backend): + assert backend.get("nope") is None + + +def test_clear_user_deletes_all_user_facts(backend, user_id, other_user_id): + backend.remember(content="a", user_id=user_id) + backend.remember(content="b", user_id=user_id) + backend.remember(content="c", user_id=other_user_id) + deleted = backend.clear_user(user_id=user_id) + assert deleted == 2 + assert backend.list_all(user_id=user_id) == [] + assert len(backend.list_all(user_id=other_user_id)) == 1 diff --git a/tests/unit/memory/long_term/test_pinecone_backend.py 
b/tests/unit/memory/long_term/test_pinecone_backend.py new file mode 100644 index 000000000..36f51797c --- /dev/null +++ b/tests/unit/memory/long_term/test_pinecone_backend.py @@ -0,0 +1,273 @@ +"""Tests for PineconeLongTermMemoryBackend. + +These exercise the backend's storage primitives against an in-process fake +Pinecone index that mimics the v3 client API used by the backend +(`upsert`, `fetch`, `query`, `delete`). No live Pinecone calls are made. +""" +import math +from datetime import UTC, datetime + +import pytest + +from dynamiq.connections import Pinecone as PineconeConnection +from dynamiq.memory.long_term.backends.pinecone import PineconeLongTermMemoryBackend +from dynamiq.memory.long_term.schemas import Fact + +# --- Fake Pinecone client / index ------------------------------------------ + +# Pinecone metadata filter is MongoDB-style. We support the subset the backend +# emits: `{key: {"$eq": value}}` and `{"$and": [..., ...]}`. + + +def _matches_filter(metadata: dict, flt: dict | None) -> bool: + if not flt: + return True + if "$and" in flt: + return all(_matches_filter(metadata, sub) for sub in flt["$and"]) + for key, predicate in flt.items(): + if isinstance(predicate, dict) and "$eq" in predicate: + if metadata.get(key) != predicate["$eq"]: + return False + else: + if metadata.get(key) != predicate: + return False + return True + + +def _cosine(a: list[float], b: list[float]) -> float: + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) or 1.0 + nb = math.sqrt(sum(x * x for x in b)) or 1.0 + return dot / (na * nb) + + +class _FakeIndex: + def __init__(self) -> None: + # namespace -> id -> {"id", "values", "metadata"} + self.store: dict[str, dict[str, dict]] = {} + + def _ns(self, namespace: str) -> dict[str, dict]: + return self.store.setdefault(namespace, {}) + + def upsert(self, vectors, namespace="default"): + ns = self._ns(namespace) + for vec in vectors: + ns[vec["id"]] = {"id": vec["id"], "values": 
list(vec["values"]), "metadata": dict(vec["metadata"])} + return {"upserted_count": len(vectors)} + + def fetch(self, ids, namespace="default"): + ns = self._ns(namespace) + return {"vectors": {i: ns[i] for i in ids if i in ns}} + + def delete(self, ids=None, namespace="default", filter=None): + ns = self._ns(namespace) + if ids is not None: + for i in ids: + ns.pop(i, None) + elif filter is not None: + for i, item in list(ns.items()): + if _matches_filter(item["metadata"], filter): + ns.pop(i, None) + return {} + + def query(self, vector, top_k, namespace="default", filter=None, include_metadata=True, **_): + ns = self._ns(namespace) + candidates = [item for item in ns.values() if _matches_filter(item["metadata"], filter)] + scored = [(item, _cosine(vector, item["values"])) for item in candidates] + scored.sort(key=lambda pair: pair[1], reverse=True) + matches = [] + for item, score in scored[:top_k]: + entry = {"id": item["id"], "score": score} + if include_metadata: + entry["metadata"] = item["metadata"] + matches.append(entry) + return {"matches": matches} + + +class _FakeClient: + def __init__(self) -> None: + self.indexes: dict[str, _FakeIndex] = {} + + def Index(self, name): # noqa: N802 — mirrors Pinecone client API + return self.indexes.setdefault(name, _FakeIndex()) + + +# --- Fixtures --------------------------------------------------------------- + + +@pytest.fixture +def fake_pinecone_client(monkeypatch): + client = _FakeClient() + monkeypatch.setattr(PineconeConnection, "connect", lambda self: client) + return client + + +@pytest.fixture +def backend(fake_embedder, fake_pinecone_client): + return PineconeLongTermMemoryBackend( + connection=PineconeConnection(api_key="test-key"), + embedder=fake_embedder, + index_name="user_facts", + namespace="test", + dimension=fake_embedder.DIM, + ) + + +def _fact(fact_id: str, user_id: str, content: str, content_hash: str | None = None) -> Fact: + now = datetime.now(UTC) + return Fact( + id=fact_id, + 
content=content, + hash=content_hash or f"h-{fact_id}", + user_id=user_id, + metadata={}, + created_at=now, + updated_at=now, + ) + + +# --- insert / get / get_by_hash -------------------------------------------- + + +def test_pinecone_insert_then_get(backend, fake_embedder): + fact = _fact("f1", "u1", "hello") + backend.insert(fact, fake_embedder.embed("hello")) + fetched = backend.get("f1") + assert fetched is not None and fetched.id == "f1" and fetched.content == "hello" + + +def test_pinecone_get_unknown_returns_none(backend): + assert backend.get("does-not-exist") is None + + +def test_pinecone_get_by_hash(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x", "h-shared"), fake_embedder.embed("x")) + found = backend.get_by_hash(user_id="u1", content_hash="h-shared") + assert found is not None and found.id == "f1" + + +def test_pinecone_get_by_hash_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x", "h-shared"), fake_embedder.embed("x")) + assert backend.get_by_hash(user_id="u2", content_hash="h-shared") is None + + +def test_pinecone_metadata_round_trip(backend, fake_embedder): + """Free-form metadata must survive Pinecone's flat-schema constraint via JSON encoding.""" + fact = _fact("f1", "u1", "x").model_copy( + update={"metadata": {"category": "preference", "score": 0.8}} + ) + backend.insert(fact, fake_embedder.embed("x")) + fetched = backend.get("f1") + assert fetched.metadata == {"category": "preference", "score": 0.8} + + +# --- delete / list_by_scope / delete_scope --------------------------------- + + +def test_pinecone_delete(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x"), fake_embedder.embed("x")) + backend.delete("f1") + assert backend.get("f1") is None + + +def test_pinecone_list_by_scope(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), 
fake_embedder.embed("c")) + listed = backend.list_by_scope({"user_id": "u1"}) + assert {f.id for f in listed} == {"f1", "f2"} + + +def test_pinecone_delete_scope_returns_accurate_count(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + assert backend.delete_scope({"user_id": "u1"}) == 2 + assert backend.list_by_scope({"user_id": "u1"}) == [] + assert len(backend.list_by_scope({"user_id": "u2"})) == 1 + + +def test_pinecone_delete_scope_empty_returns_zero(backend): + assert backend.delete_scope({"user_id": "nobody"}) == 0 + + +def test_pinecone_delete_scope_paginates_beyond_single_page(backend, fake_embedder, monkeypatch): + """clear_user on users with more facts than fit in one query page must still + delete everything and report the true count — not silently cap at one page.""" + monkeypatch.setattr(backend, "_LIST_PAGE_SIZE", 2) + for i in range(5): + backend.insert(_fact(f"f{i}", "u1", f"c{i}"), fake_embedder.embed(f"c{i}")) + assert backend.delete_scope({"user_id": "u1"}) == 5 + assert backend.list_by_scope({"user_id": "u1"}) == [] + + +# --- search ---------------------------------------------------------------- + + +def test_pinecone_search_relevance_ordered(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u1", "alpha-2"), fake_embedder.embed("alpha-2")) + backend.insert(_fact("f3", "u1", "zulu"), fake_embedder.embed("zulu")) + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=3, + ) + assert hits[0][0].id == "f1" + scores = [score for _, score in hits] + assert scores == sorted(scores, reverse=True) + + +def test_pinecone_search_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + 
backend.insert(_fact("f2", "u2", "alpha"), fake_embedder.embed("alpha")) + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=5, + ) + assert [fact.id for fact, _ in hits] == ["f1"] + + +def test_pinecone_search_empty_returns_empty(backend, fake_embedder): + hits = backend.search( + query_embedding=fake_embedder.embed("x"), + scope={"user_id": "u1"}, + limit=5, + ) + assert hits == [] + + +# --- high-level operations via Template Method ------------------------------ + + +def test_pinecone_remember_and_recall_through_backend(backend, fake_embedder): + """End-to-end remember/recall must work through the backend's high-level API, + confirming the storage primitives are wired correctly.""" + backend.remember(content="User likes pizza", user_id="u1") + backend.remember(content="User likes Python", user_id="u1") + hits = backend.recall(query="pizza preferences", user_id="u1", limit=5) + contents = {fact.content for fact, _ in hits} + assert {"User likes pizza", "User likes Python"} <= contents + + +# --- serialization --------------------------------------------------------- + + +def test_pinecone_to_dict_excludes_live_clients_and_includes_connection(backend): + """`to_dict` must drop the runtime client/index but emit connection + embedder + so the YAML round-trip rebuilds an equivalent backend.""" + data = backend.to_dict() + assert "_client" not in data and "_index" not in data + assert isinstance(data["connection"], dict) + assert isinstance(data["embedder"], dict) + # Persistent backend identity must survive serialization. 
+ assert data["index_name"] == "user_facts" + assert data["namespace"] == "test" + + +def test_pinecone_to_dict_accepts_include_secure_params(backend): + """`include_secure_params=True` must propagate through backend → connection + without raising.""" + data = backend.to_dict(include_secure_params=True) + assert "connection" in data and "embedder" in data diff --git a/tests/unit/memory/long_term/test_schemas.py b/tests/unit/memory/long_term/test_schemas.py new file mode 100644 index 000000000..820cddef8 --- /dev/null +++ b/tests/unit/memory/long_term/test_schemas.py @@ -0,0 +1,25 @@ +from datetime import UTC, datetime + +from dynamiq.memory.long_term.schemas import Fact + + +def test_fact_round_trip(): + now = datetime.now(UTC) + fact = Fact( + id="f1", + content="User prefers terse responses", + hash="abcd1234", + user_id="u1", + metadata={"category": "preference"}, + created_at=now, + updated_at=now, + ) + dumped = fact.model_dump() + assert Fact(**dumped) == fact + + +def test_fact_metadata_defaults_to_empty_dict(): + now = datetime.now(UTC) + fact = Fact(id="f1", content="x", hash="h", user_id="u", + created_at=now, updated_at=now) + assert fact.metadata == {} diff --git a/tests/unit/memory/long_term/test_tools.py b/tests/unit/memory/long_term/test_tools.py new file mode 100644 index 000000000..073cb13dd --- /dev/null +++ b/tests/unit/memory/long_term/test_tools.py @@ -0,0 +1,226 @@ +from dynamiq.memory.long_term.backends.in_memory import InMemoryLongTermMemoryBackend +from dynamiq.nodes.tools.long_term_memory import RecallFactsTool, RememberFactTool, build_long_term_memory_tools + + +# --- RememberFactTool --- + + +def test_remember_tool_persists_a_fact(backend, user_id): + tool = RememberFactTool(backend=backend, user_id=user_id) + result = tool.execute(tool.input_schema(content="User likes pizza")) + fact_id = result["content"]["fact_id"] + assert result["content"]["outcome"] == "created" + assert backend.get(fact_id).content == "User likes pizza" + + +def 
test_remember_tool_input_schema_has_no_user_id(): + """LLM-visible signature must not contain user_id — it's instance state.""" + assert "user_id" not in RememberFactTool.input_schema.model_fields + assert {"content", "metadata"} <= set(RememberFactTool.input_schema.model_fields) + + +def test_remember_tool_uses_construction_user_id(backend, user_id): + tool = RememberFactTool(backend=backend, user_id=user_id) + result = tool.execute(tool.input_schema(content="x")) + fact = backend.get(result["content"]["fact_id"]) + assert fact.user_id == user_id + + +def test_remember_tool_idempotent_on_duplicate(backend, user_id): + tool = RememberFactTool(backend=backend, user_id=user_id) + a = tool.execute(tool.input_schema(content="x")) + b = tool.execute(tool.input_schema(content="x")) + assert a["content"]["fact_id"] == b["content"]["fact_id"] + assert b["content"]["outcome"] == "unchanged" + + +def test_remember_tool_accepts_metadata(backend, user_id): + tool = RememberFactTool(backend=backend, user_id=user_id) + result = tool.execute(tool.input_schema(content="x", metadata={"category": "preference"})) + fact = backend.get(result["content"]["fact_id"]) + assert fact.metadata == {"category": "preference"} + + +def test_remember_tool_agent_optimized_returns_status_string(backend, user_id): + """Agent-mode output is a short human-readable status, not a dict.""" + tool = RememberFactTool(backend=backend, user_id=user_id) + tool.is_optimized_for_agents = True + + created = tool.execute(tool.input_schema(content="User likes pizza")) + assert created["content"] == "Fact saved." + + unchanged = tool.execute(tool.input_schema(content="User likes pizza")) + assert unchanged["content"] == "Already remembered." 
+ + +def test_remember_tool_agent_optimized_reports_update(fake_embedder, user_id): + """Agent-mode upsert renders as 'Fact updated.'""" + backend = InMemoryLongTermMemoryBackend(embedder=fake_embedder, upsert_threshold=0.0) + tool = RememberFactTool(backend=backend, user_id=user_id) + tool.is_optimized_for_agents = True + + tool.execute(tool.input_schema(content="User likes pizza")) + updated = tool.execute(tool.input_schema(content="User loves pizza")) + assert updated["content"] == "Fact updated." + + +# --- RecallFactsTool --- + + +def test_recall_tool_returns_hits(backend, user_id): + backend.remember(content="User likes pizza", user_id=user_id) + backend.remember(content="User likes Python", user_id=user_id) + tool = RecallFactsTool(backend=backend, user_id=user_id) + result = tool.execute(tool.input_schema(queries=["pizza"], limit=2)) + items = result["content"] + assert len(items) == 2 + for item in items: + assert {"content", "score"} <= set(item.keys()) + scores = [it["score"] for it in items] + assert scores == sorted(scores, reverse=True) + + +def test_recall_tool_input_schema_has_no_user_id(): + assert "user_id" not in RecallFactsTool.input_schema.model_fields + assert {"queries", "limit"} <= set(RecallFactsTool.input_schema.model_fields) + + +def test_recall_tool_isolates_users(backend, user_id, other_user_id): + backend.remember(content="A's fact", user_id=user_id) + backend.remember(content="B's fact", user_id=other_user_id) + tool = RecallFactsTool(backend=backend, user_id=user_id) + result = tool.execute(tool.input_schema(queries=["fact"], limit=5)) + contents = {item["content"] for item in result["content"]} + assert contents == {"A's fact"} + + +def test_recall_tool_empty_store_returns_empty(backend, user_id): + tool = RecallFactsTool(backend=backend, user_id=user_id) + result = tool.execute(tool.input_schema(queries=["anything"])) + assert result["content"] == [] + + +def test_recall_tool_agent_optimized_returns_bullet_list(backend, user_id): + 
backend.remember(content="User likes pizza", user_id=user_id) + backend.remember(content="User likes Python", user_id=user_id) + tool = RecallFactsTool(backend=backend, user_id=user_id) + tool.is_optimized_for_agents = True + result = tool.execute(tool.input_schema(queries=["pizza"], limit=2)) + assert isinstance(result["content"], str) + assert "- User likes pizza" in result["content"] + assert "- User likes Python" in result["content"] + + +def test_recall_tool_agent_optimized_empty_message(backend, user_id): + tool = RecallFactsTool(backend=backend, user_id=user_id) + tool.is_optimized_for_agents = True + result = tool.execute(tool.input_schema(queries=["anything"])) + assert result["content"] == "No relevant facts." + + +def test_recall_tool_merges_multiple_queries_and_dedupes(backend, user_id): + """Multiple phrasings hitting the same fact must yield one entry, not duplicates.""" + backend.remember(content="User likes pizza", user_id=user_id) + backend.remember(content="User likes Python", user_id=user_id) + tool = RecallFactsTool(backend=backend, user_id=user_id) + result = tool.execute( + tool.input_schema(queries=["pizza", "User likes pizza", "favourite food"], limit=5) + ) + contents = [it["content"] for it in result["content"]] + assert len(contents) == len(set(contents)) + assert "User likes pizza" in contents + + +def test_recall_tool_rejects_empty_queries_list(): + """min_length=1 on `queries` must reject an empty list at schema validation.""" + import pytest as _pytest + + with _pytest.raises(Exception): + RecallFactsTool.input_schema(queries=[]) + + +def test_recall_tool_rejects_whitespace_only_query(): + """A blank or whitespace-only entry must be caught at validation time, not + when the backend raises mid-execute.""" + import pytest as _pytest + + with _pytest.raises(Exception): + RecallFactsTool.input_schema(queries=[" "]) + with _pytest.raises(Exception): + RecallFactsTool.input_schema(queries=["valid", ""]) + + +def 
test_recall_tool_strips_query_whitespace(backend, user_id): + """Surrounding whitespace must be stripped so leading/trailing spaces don't + affect the embedding (or cause spurious cache misses).""" + backend.remember(content="User likes pizza", user_id=user_id) + tool = RecallFactsTool(backend=backend, user_id=user_id) + result = tool.execute(tool.input_schema(queries=[" pizza "])) + assert result["content"], "stripped query should still match the stored fact" + + +# --- factory --- + + +def test_factory_builds_default_two_tools(backend, user_id): + tools = build_long_term_memory_tools(backend=backend, user_id=user_id) + assert {t.name for t in tools} == {"remember_fact", "recall_facts"} + + +def test_factory_respects_include(backend, user_id): + tools = build_long_term_memory_tools(backend=backend, user_id=user_id, include=("recall",)) + assert [t.name for t in tools] == ["recall_facts"] + + +def test_factory_bakes_user_id_into_each_tool(backend, user_id): + tools = build_long_term_memory_tools(backend=backend, user_id=user_id) + for tool in tools: + assert tool.user_id == user_id + + +def test_factory_ignores_unknown_include_keys(backend, user_id): + tools = build_long_term_memory_tools(backend=backend, user_id=user_id, include=("recall", "unknown", "forget")) + assert [t.name for t in tools] == ["recall_facts"] + + +def test_factory_skips_enum_members_missing_from_builders(backend, user_id, monkeypatch): + """Valid `MemoryToolKind` values without a `_TOOL_BUILDERS` entry must be + silently skipped, not KeyError. 
Mirrors the unknown-string branch so the + docstring's "unknown keys are ignored" promise actually holds.""" + from dynamiq.memory.long_term.types import MemoryToolKind + from dynamiq.nodes.tools import long_term_memory as ltm_tools_module + + monkeypatch.setattr(ltm_tools_module, "_TOOL_BUILDERS", {MemoryToolKind.RECALL: ltm_tools_module.RecallFactsTool}) + tools = build_long_term_memory_tools( + backend=backend, + user_id=user_id, + include=(MemoryToolKind.REMEMBER, MemoryToolKind.RECALL), + ) + assert [t.name for t in tools] == ["recall_facts"] + + +# --- serialization --- + + +def test_remember_tool_to_dict_excludes_live_backend(backend, user_id): + """`to_dict` must not auto-dump `backend` (it holds runtime clients + embedder). + + The default `model_dump` would try to JSON-encode the embedder's connection + and the backend's live client, blowing up tracing callbacks. The tool base + excludes the field and re-adds it via `LongTermMemoryBackend.to_dict()`. + """ + tool = RememberFactTool(backend=backend, user_id=user_id) + data = tool.to_dict() + assert "backend" in data + backend_dump = data["backend"] + assert isinstance(backend_dump, dict) + assert "embedder" in backend_dump and isinstance(backend_dump["embedder"], dict) + + +def test_remember_tool_to_dict_accepts_include_secure_params(backend, user_id): + """`include_secure_params=True` must propagate through tool → backend → connection + without raising. Connection.to_dict swallows the kwarg; backends pass it through.""" + tool = RememberFactTool(backend=backend, user_id=user_id) + data = tool.to_dict(include_secure_params=True) + assert "backend" in data + assert "embedder" in data["backend"] diff --git a/tests/unit/memory/long_term/test_weaviate_backend.py b/tests/unit/memory/long_term/test_weaviate_backend.py new file mode 100644 index 000000000..d2ec3f9f4 --- /dev/null +++ b/tests/unit/memory/long_term/test_weaviate_backend.py @@ -0,0 +1,376 @@ +"""Tests for WeaviateLongTermMemoryBackend. 
+ +These exercise the backend's storage primitives against an in-process fake +Weaviate v4 collection that mimics the subset of the API the backend uses. +No live Weaviate calls are made. +""" +import math +from datetime import UTC, datetime +from types import SimpleNamespace + +import pytest + +from dynamiq.connections import Weaviate as WeaviateConnection +from dynamiq.memory.long_term.backends.weaviate import ( + WeaviateLongTermMemoryBackend, + _to_weaviate_uuid, +) +from dynamiq.memory.long_term.schemas import Fact + +# --- Fake Weaviate client / collection ------------------------------------- + + +def _cosine_distance(a: list[float], b: list[float]) -> float: + dot = sum(x * y for x, y in zip(a, b)) + na = math.sqrt(sum(x * x for x in a)) or 1.0 + nb = math.sqrt(sum(x * x for x in b)) or 1.0 + return 1.0 - dot / (na * nb) + + +class _FakeData: + def __init__(self, store: dict) -> None: + self.store = store + + def insert(self, *, uuid, properties, vector): + self.store[uuid] = {"uuid": uuid, "properties": dict(properties), "vector": list(vector)} + + def replace(self, *, uuid, properties, vector): + self.store[uuid] = {"uuid": uuid, "properties": dict(properties), "vector": list(vector)} + + def delete_by_id(self, *, uuid): + self.store.pop(uuid, None) + + def delete_many(self, *, where): + for uid, item in list(self.store.items()): + if where.matches(item): + self.store.pop(uid, None) + + +class _FakeQuery: + def __init__(self, store: dict) -> None: + self.store = store + + def fetch_object_by_id(self, *, uuid): + item = self.store.get(uuid) + if item is None: + return None + return SimpleNamespace(uuid=uuid, properties=item["properties"], metadata=None) + + def fetch_objects(self, *, filters=None, limit=10): + candidates = [ + SimpleNamespace(uuid=uid, properties=item["properties"], metadata=None) + for uid, item in self.store.items() + if filters is None or filters.matches(item) + ] + return SimpleNamespace(objects=candidates[:limit]) + + def 
near_vector(self, *, near_vector, limit=10, filters=None, return_metadata=None): + candidates = [ + (uid, item, _cosine_distance(near_vector, item["vector"])) + for uid, item in self.store.items() + if filters is None or filters.matches(item) + ] + candidates.sort(key=lambda t: t[2]) + return SimpleNamespace( + objects=[ + SimpleNamespace( + uuid=uid, + properties=item["properties"], + metadata=SimpleNamespace(distance=dist), + ) + for uid, item, dist in candidates[:limit] + ] + ) + + +class _FakeCollection: + def __init__(self) -> None: + self.store: dict[str, dict] = {} + self.data = _FakeData(self.store) + self.query = _FakeQuery(self.store) + + +class _FakeCollections: + def __init__(self) -> None: + self.collections: dict[str, _FakeCollection] = {} + + def get(self, name): + return self.collections.setdefault(name, _FakeCollection()) + + def exists(self, name): + return name in self.collections + + def create(self, *, name, **_): + self.collections.setdefault(name, _FakeCollection()) + + def delete(self, name): + self.collections.pop(name, None) + + +class _FakeClient: + def __init__(self) -> None: + self.collections = _FakeCollections() + + +# We bypass the real weaviate `Filter` objects entirely — the backend's +# `_scope_to_filter` builds them via `Filter.by_property(...).equal(...)`, +# which calls into the weaviate library. For mock tests we monkeypatch the +# scope-to-filter helper to return a callable predicate the fakes can evaluate. 
+ + +class _PredicateFilter: + def __init__(self, predicate) -> None: + self._predicate = predicate + + def matches(self, item) -> bool: + return self._predicate(item) + + def __and__(self, other): + return _PredicateFilter(lambda item: self._predicate(item) and other._predicate(item)) + + +def _fake_scope_to_filter(scope: dict): + if not scope: + return None + return _PredicateFilter(lambda item: all(item["properties"].get(k) == v for k, v in scope.items())) + + +def _fake_id_in_filter(uuids): + uuid_set = set(uuids) + return _PredicateFilter(lambda item: item["uuid"] in uuid_set) + + +# --- Fixtures --------------------------------------------------------------- + + +@pytest.fixture +def fake_weaviate_client(monkeypatch): + client = _FakeClient() + monkeypatch.setattr(WeaviateConnection, "connect", lambda self: client) + # Swap in a fake scope_to_filter so the backend uses our predicate fakes + # instead of real weaviate Filter objects (which the fake store can't evaluate). + import dynamiq.memory.long_term.backends.weaviate as weaviate_backend + + monkeypatch.setattr(weaviate_backend, "_scope_to_filter", _fake_scope_to_filter) + monkeypatch.setattr(weaviate_backend, "_id_in_filter", _fake_id_in_filter) + return client + + +@pytest.fixture +def backend(fake_embedder, fake_weaviate_client): + backend = WeaviateLongTermMemoryBackend( + connection=WeaviateConnection(api_key="test-key", url="http://localhost"), + embedder=fake_embedder, + collection_name="UserFacts", + dimension=fake_embedder.DIM, + ) + # The backend resolves its collection lazily via `collections.get` on first + # use, and our fake auto-creates it on access — so no setup is needed here. 
+ return backend + + +def _fact(fact_id: str, user_id: str, content: str, content_hash: str | None = None) -> Fact: + now = datetime.now(UTC) + return Fact( + id=fact_id, + content=content, + hash=content_hash or f"h-{fact_id}", + user_id=user_id, + metadata={}, + created_at=now, + updated_at=now, + ) + + +# --- insert / get / get_by_hash -------------------------------------------- + + +def test_weaviate_insert_then_get(backend, fake_embedder): + fact = _fact("f1", "u1", "hello") + backend.insert(fact, fake_embedder.embed("hello")) + fetched = backend.get("f1") + assert fetched is not None and fetched.id == "f1" and fetched.content == "hello" + + +def test_weaviate_get_unknown_returns_none(backend): + assert backend.get("does-not-exist") is None + + +def test_weaviate_get_by_hash(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x", "h-shared"), fake_embedder.embed("x")) + found = backend.get_by_hash(user_id="u1", content_hash="h-shared") + assert found is not None and found.id == "f1" + + +def test_weaviate_get_by_hash_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x", "h-shared"), fake_embedder.embed("x")) + assert backend.get_by_hash(user_id="u2", content_hash="h-shared") is None + + +def test_weaviate_metadata_round_trip(backend, fake_embedder): + """Free-form metadata must round-trip through the JSON-encoded property.""" + fact = _fact("f1", "u1", "x").model_copy(update={"metadata": {"category": "preference", "score": 0.8}}) + backend.insert(fact, fake_embedder.embed("x")) + assert backend.get("f1").metadata == {"category": "preference", "score": 0.8} + + +def test_weaviate_construction_does_not_touch_collection(fake_embedder, monkeypatch): + """A fresh backend must construct cleanly without resolving the collection — + that lookup is deferred to first use so `ensure_collection()` can run after.""" + + class _StrictCollections: + def __init__(self) -> None: + self.get_called_with: list = [] + + def get(self, name): + 
self.get_called_with.append(name) + return _FakeCollection() + + class _StrictClient: + def __init__(self) -> None: + self.collections = _StrictCollections() + + client = _StrictClient() + monkeypatch.setattr(WeaviateConnection, "connect", lambda self: client) + backend = WeaviateLongTermMemoryBackend( + connection=WeaviateConnection(api_key="k", url="http://localhost"), + embedder=fake_embedder, + collection_name="UserFacts", + dimension=fake_embedder.DIM, + ) + assert client.collections.get_called_with == [] # not yet resolved + _ = backend._collection # first access resolves + assert client.collections.get_called_with == ["UserFacts"] + + +def test_weaviate_fact_id_maps_to_deterministic_uuid(): + """Two backends must resolve the same fact_id to the same UUID — so a fact + inserted by one process can be deleted by another via the original id.""" + assert _to_weaviate_uuid("fact-1") == _to_weaviate_uuid("fact-1") + assert _to_weaviate_uuid("fact-1") != _to_weaviate_uuid("fact-2") + + +# --- delete / update / list_by_scope / delete_scope ------------------------ + + +def test_weaviate_delete(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "x"), fake_embedder.embed("x")) + backend.delete("f1") + assert backend.get("f1") is None + + +def test_weaviate_update_replaces_content_and_vector(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "old"), fake_embedder.embed("old")) + backend.update( + "f1", + content="new", + content_hash="h-new", + embedding=fake_embedder.embed("new"), + metadata={"k": "v"}, + updated_at=datetime.now(UTC), + ) + fetched = backend.get("f1") + assert fetched.content == "new" and fetched.hash == "h-new" and fetched.metadata == {"k": "v"} + + +def test_weaviate_list_by_scope(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + listed = 
backend.list_by_scope({"user_id": "u1"}) + assert {f.id for f in listed} == {"f1", "f2"} + + +def test_weaviate_delete_scope_returns_accurate_count(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "a"), fake_embedder.embed("a")) + backend.insert(_fact("f2", "u1", "b"), fake_embedder.embed("b")) + backend.insert(_fact("f3", "u2", "c"), fake_embedder.embed("c")) + assert backend.delete_scope({"user_id": "u1"}) == 2 + assert backend.list_by_scope({"user_id": "u1"}) == [] + assert len(backend.list_by_scope({"user_id": "u2"})) == 1 + + +def test_weaviate_delete_scope_empty_returns_zero(backend): + assert backend.delete_scope({"user_id": "nobody"}) == 0 + + +def test_weaviate_delete_scope_paginates_beyond_single_page_with_scope(backend, fake_embedder, monkeypatch): + """A scoped delete must remove every match and return the true count even + when the matched set exceeds Weaviate's per-call fetch cap.""" + monkeypatch.setattr(type(backend), "_SCOPE_PAGE_SIZE", 2) + for i in range(5): + backend.insert(_fact(f"f{i}", "u1", f"c{i}"), fake_embedder.embed(f"c{i}")) + assert backend.delete_scope({"user_id": "u1"}) == 5 + assert backend.list_by_scope({"user_id": "u1"}) == [] + + +def test_weaviate_delete_scope_empty_paginates_unbounded(backend, fake_embedder, monkeypatch): + """Empty scope must clear the entire collection — not just the first page.""" + monkeypatch.setattr(type(backend), "_SCOPE_PAGE_SIZE", 2) + for i in range(5): + backend.insert(_fact(f"f{i}", f"u{i % 2}", f"c{i}"), fake_embedder.embed(f"c{i}")) + assert backend.delete_scope({}) == 5 + + +# --- search ---------------------------------------------------------------- + + +def test_weaviate_search_relevance_ordered(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u1", "alpha-2"), fake_embedder.embed("alpha-2")) + backend.insert(_fact("f3", "u1", "zulu"), fake_embedder.embed("zulu")) + hits = backend.search( + 
query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=3, + ) + assert hits[0][0].id == "f1" + scores = [score for _, score in hits] + assert scores == sorted(scores, reverse=True) + + +def test_weaviate_search_isolates_users(backend, fake_embedder): + backend.insert(_fact("f1", "u1", "alpha"), fake_embedder.embed("alpha")) + backend.insert(_fact("f2", "u2", "alpha"), fake_embedder.embed("alpha")) + hits = backend.search( + query_embedding=fake_embedder.embed("alpha"), + scope={"user_id": "u1"}, + limit=5, + ) + assert [fact.id for fact, _ in hits] == ["f1"] + + +def test_weaviate_search_empty_returns_empty(backend, fake_embedder): + hits = backend.search( + query_embedding=fake_embedder.embed("x"), + scope={"user_id": "u1"}, + limit=5, + ) + assert hits == [] + + +# --- high-level operations via Template Method ------------------------------ + + +def test_weaviate_remember_and_recall_through_backend(backend): + backend.remember(content="User likes pizza", user_id="u1") + backend.remember(content="User likes Python", user_id="u1") + hits = backend.recall(query="pizza preferences", user_id="u1", limit=5) + contents = {fact.content for fact, _ in hits} + assert {"User likes pizza", "User likes Python"} <= contents + + +# --- serialization --------------------------------------------------------- + + +def test_weaviate_to_dict_excludes_live_clients_and_includes_connection(backend): + data = backend.to_dict() + assert "_client" not in data and "_collection" not in data + assert isinstance(data["connection"], dict) + assert isinstance(data["embedder"], dict) + assert data["collection_name"] == "UserFacts" + + +def test_weaviate_to_dict_accepts_include_secure_params(backend): + data = backend.to_dict(include_secure_params=True) + assert "connection" in data and "embedder" in data diff --git a/tests/unit/nodes/agents/test_long_term_memory_integration.py b/tests/unit/nodes/agents/test_long_term_memory_integration.py new file mode 100644 index 
000000000..9b0e7ac5a --- /dev/null +++ b/tests/unit/nodes/agents/test_long_term_memory_integration.py @@ -0,0 +1,378 @@ +import hashlib +from types import SimpleNamespace +from typing import ClassVar +from unittest.mock import patch + +import pytest + +from dynamiq.connections import BaseConnection +from dynamiq.connections import OpenAI as OpenAIConnection +from dynamiq.memory.long_term import LongTermMemoryConfig +from dynamiq.memory.long_term.backends.in_memory import InMemoryLongTermMemoryBackend +from dynamiq.nodes.agents.base import Agent +from dynamiq.nodes.embedders.base import TextEmbedder, TextEmbedderInputSchema +from dynamiq.nodes.llms import OpenAI + + +class _StubConnection(BaseConnection): + def connect(self) -> None: + return None + + +class _FakeEmbedder(TextEmbedder): + name: str = "fake-text-embedder" + connection: BaseConnection = _StubConnection() + DIM: ClassVar[int] = 16 + + def execute(self, input_data: TextEmbedderInputSchema, config=None, **kwargs): + text = input_data.query if hasattr(input_data, "query") else input_data["query"] + digest = hashlib.sha256(text.encode("utf-8")).digest() + raw = [(b / 127.5) - 1.0 for b in digest[: self.DIM]] + norm = sum(x * x for x in raw) ** 0.5 or 1.0 + return {"query": text, "embedding": [x / norm for x in raw]} + + +@pytest.fixture +def ltm(): + return LongTermMemoryConfig(backend=InMemoryLongTermMemoryBackend(embedder=_FakeEmbedder())) + + +@pytest.fixture +def llm(): + """Real OpenAI LLM object — never executed in these tests. 
Constructed + only to satisfy Agent's pydantic validation.""" + return OpenAI( + connection=OpenAIConnection(api_key="test-key"), + model="gpt-4o", + ) + + +def _ltm_config(*, tools=None, enabled=True) -> LongTermMemoryConfig: + backend = InMemoryLongTermMemoryBackend(embedder=_FakeEmbedder()) + kwargs = {"backend": backend, "enabled": enabled} + if tools is not None: + kwargs["tools"] = tools + return LongTermMemoryConfig(**kwargs) + + +def _make_agent(llm, *, ltm=None) -> Agent: + kwargs = {"name": "test", "llm": llm, "tools": []} + if ltm is not None: + kwargs["long_term_memory"] = ltm + return Agent(**kwargs) + + +def _input(user_id=None, session_id=None): + return SimpleNamespace(user_id=user_id, session_id=session_id, input="hi") + + +# --- LongTermMemoryConfig --- + + +def test_config_default_includes_remember_and_recall(): + assert _ltm_config().tools == ("remember", "recall") + + +def test_config_can_restrict_to_read_only(): + assert _ltm_config(tools=("recall",)).tools == ("recall",) + + +def test_config_defaults_to_enabled(): + assert _ltm_config().enabled is True + + +def test_config_model_dump_emits_plain_strings_not_enums(): + """YAML round-trip relies on tool kinds being dumped as their string values, + not as enum members (which yaml.safe_dump cannot represent and which would + round-trip back as the enum *name* — 'REMEMBER' — failing validation).""" + import yaml + + dumped = _ltm_config().model_dump(exclude={"backend"}) + assert dumped["tools"] == ("remember", "recall") + assert all(isinstance(t, str) and not hasattr(t, "value") for t in dumped["tools"]) + yaml.safe_dump(dumped) # must not raise + + +# --- Agent field declarations --- + + +def test_agent_has_long_term_memory_field(): + fields = Agent.model_fields + assert "long_term_memory" in fields + assert fields["long_term_memory"].default is None + + +def test_agent_long_term_memory_defaults_to_none(llm): + agent = _make_agent(llm) + assert agent.long_term_memory is None + + +# --- 
_build_long_term_memory_tools --- + + +def test_build_returns_default_tools_when_ltm_and_user_id_present(llm, ltm): + agent = _make_agent(llm, ltm=ltm) + tools = agent._build_long_term_memory_tools(_input(user_id="u1")) + assert {t.name for t in tools} == {"remember_fact", "recall_facts"} + + +def test_build_returns_empty_when_no_user_id(llm, ltm): + agent = _make_agent(llm, ltm=ltm) + assert agent._build_long_term_memory_tools(_input(session_id="s1")) == [] + + +def test_build_returns_empty_when_no_long_term_memory(llm): + agent = _make_agent(llm) + assert agent._build_long_term_memory_tools(_input(user_id="u1")) == [] + + +def test_build_respects_config_include(llm): + agent = _make_agent(llm, ltm=_ltm_config(tools=("recall",))) + tools = agent._build_long_term_memory_tools(_input(user_id="u1")) + assert [t.name for t in tools] == ["recall_facts"] + + +def test_build_returns_empty_when_disabled(llm): + agent = _make_agent(llm, ltm=_ltm_config(enabled=False)) + assert agent._build_long_term_memory_tools(_input(user_id="u1")) == [] + + +def test_build_bakes_user_id_into_each_tool(llm, ltm): + agent = _make_agent(llm, ltm=ltm) + tools = agent._build_long_term_memory_tools(_input(user_id="u1")) + for tool in tools: + assert tool.user_id == "u1" + + +def test_build_sets_is_optimized_for_agents_on_each_tool(llm, ltm): + """LTM tools are built per-run, after `init_components` has run, so the agent + must flip `is_optimized_for_agents` itself — otherwise remember/recall would + return raw dicts instead of the friendly status strings the LLM expects.""" + agent = _make_agent(llm, ltm=ltm) + tools = agent._build_long_term_memory_tools(_input(user_id="u1")) + assert tools and all(t.is_optimized_for_agents for t in tools) + + +def test_function_calling_schemas_include_ltm_overlay(llm, ltm): + """In FUNCTION_CALLING mode the per-call LTM tools must appear in the + generated tool schemas, otherwise the LLM can never call remember/recall.""" + from dynamiq.nodes.agents.agent 
import Agent as ReActAgent + from dynamiq.nodes.agents.base import _run_extra_tools + from dynamiq.nodes.types import InferenceMode + + agent = ReActAgent(name="t", llm=llm, tools=[], long_term_memory=ltm, inference_mode=InferenceMode.FUNCTION_CALLING) + base_tools, _ = agent._effective_inference_schemas() + base_names = {schema["function"]["name"] for schema in (base_tools or [])} + assert "remember_fact" not in base_names # not present without an overlay + + ltm_tools = agent._build_long_term_memory_tools(_input(user_id="u1")) + token = _run_extra_tools.set(ltm_tools) + try: + fc_tools, _ = agent._effective_inference_schemas() + finally: + _run_extra_tools.reset(token) + + names = {schema["function"]["name"] for schema in fc_tools} + assert {"remember_fact", "recall_facts"} <= names + + +def test_xml_prompt_includes_tool_blocks_when_only_ltm_configured(llm, ltm): + """In XML/ReAct mode the system prompt template must reserve tool blocks + when LTM is the only source of tools — otherwise the per-call tool + description has no placeholder and remember/recall stay invisible.""" + from dynamiq.nodes.agents.agent import Agent as ReActAgent + from dynamiq.nodes.types import InferenceMode + + agent = ReActAgent(name="t", llm=llm, tools=[], long_term_memory=ltm, inference_mode=InferenceMode.XML) + tools_block = agent.system_prompt_manager._prompt_blocks.get("tools", "") + assert "{{ tool_description }}" in tools_block + + +def test_xml_prompt_omits_tool_blocks_when_ltm_disabled(llm): + """Disabled LTM must not flip `has_tools` on — the template should still + render the no-tools instructions when nothing else provides tools.""" + from dynamiq.nodes.agents.agent import Agent as ReActAgent + from dynamiq.nodes.types import InferenceMode + + agent = ReActAgent( + name="t", + llm=llm, + tools=[], + long_term_memory=_ltm_config(enabled=False), + inference_mode=InferenceMode.XML, + ) + assert agent.system_prompt_manager._prompt_blocks.get("tools", "") == "" + + +def 
test_init_components_initializes_ltm_embedder(llm): + """The embedder is a ConnectionNode whose `text_embedder` client is built + during `init_components`; without that, the first recall AttributeErrors + on a None client.""" + init_calls: list = [] + + class _RecordingEmbedder(_FakeEmbedder): + is_postponed_component_init: bool = True + + def init_components(self, connection_manager=None): + init_calls.append(connection_manager) + + ltm_with_postponed = LongTermMemoryConfig(backend=InMemoryLongTermMemoryBackend(embedder=_RecordingEmbedder())) + agent = _make_agent(llm, ltm=ltm_with_postponed) + # Node.__init__ already invokes init_components on construction; clear and + # assert the explicit call also propagates to the embedder. + init_calls.clear() + agent.init_components() + assert len(init_calls) == 1 + + +# --- per-call ContextVar overlay: LTM tools never mutate self.tools --- + + +def _patch_run_agent_capture_runtime_tools(agent, captured): + """Capture what the LLM-facing `tool_by_names` resolution sees mid-run.""" + + def fake_run(*args, **kwargs): + captured.append(set(agent.tool_by_names.keys())) + return "ok" + + return patch.object(agent, "_run_agent", side_effect=fake_run) + + +def test_execute_exposes_ltm_tools_during_run_only(llm, ltm): + """LTM tools must be visible to the tool-resolution properties during the + run, and absent from both `self.tools` and the properties after.""" + agent = _make_agent(llm, ltm=ltm) + original_tools = list(agent.tools) + captured: list[set[str]] = [] + + with _patch_run_agent_capture_runtime_tools(agent, captured): + agent.run_sync(input_data={"input": "hi", "user_id": "u1"}) + + assert {"remember_fact", "recall_facts"} <= captured[0] + assert agent.tools == original_tools + assert {"remember_fact", "recall_facts"}.isdisjoint(agent.tool_by_names.keys()) + + +def test_execute_clears_ltm_overlay_even_when_run_raises(llm, ltm): + agent = _make_agent(llm, ltm=ltm) + original_tools = list(agent.tools) + + with 
patch.object(agent, "_run_agent", side_effect=RuntimeError("boom")): + agent.run_sync(input_data={"input": "hi", "user_id": "u1"}) + + assert agent.tools == original_tools + assert {"remember_fact", "recall_facts"}.isdisjoint(agent.tool_by_names.keys()) + + +def test_execute_no_ltm_overlay_when_no_user_id(llm, ltm): + agent = _make_agent(llm, ltm=ltm) + captured: list[set[str]] = [] + + with _patch_run_agent_capture_runtime_tools(agent, captured): + agent.run_sync(input_data={"input": "hi"}) + + assert {"remember_fact", "recall_facts"}.isdisjoint(captured[0]) + + +def test_execute_no_ltm_overlay_when_no_long_term_memory(llm): + agent = _make_agent(llm) + captured: list[set[str]] = [] + + with _patch_run_agent_capture_runtime_tools(agent, captured): + agent.run_sync(input_data={"input": "hi", "user_id": "u1"}) + + assert {"remember_fact", "recall_facts"}.isdisjoint(captured[0]) + + +def test_execute_does_not_serialize_concurrent_calls_when_ltm_configured(llm, ltm): + """With the ContextVar overlay, concurrent execute() calls on the same + LTM-configured agent must run truly in parallel — no shared lock.""" + import threading + from concurrent.futures import ThreadPoolExecutor + + agent = _make_agent(llm, ltm=ltm) + barrier = threading.Barrier(2, timeout=5) + + def fake_run(*args, **kwargs): + # If a lock was serialising us, the second thread would never reach + # the barrier and we'd time out — barrier verifies true concurrency. 
+ barrier.wait() + return "ok" + + with patch.object(agent, "_run_agent", side_effect=fake_run): + with ThreadPoolExecutor(max_workers=2) as pool: + futures = [ + pool.submit(agent.run_sync, input_data={"input": "hi", "user_id": "u1"}), + pool.submit(agent.run_sync, input_data={"input": "hi", "user_id": "u2"}), + ] + for f in futures: + f.result(timeout=10) + + +def test_concurrent_calls_isolate_per_user_ltm_tools(llm, ltm): + """Two concurrent execute() calls with different user_ids must each see + only their own LTM tools via the per-task ContextVar overlay.""" + import threading + from concurrent.futures import ThreadPoolExecutor + + agent = _make_agent(llm, ltm=ltm) + snapshots: dict[str, set[str]] = {} + snapshots_lock = threading.Lock() + barrier = threading.Barrier(2, timeout=5) + + def fake_run(*args, **kwargs): + # Wait so both threads are inside _run_agent simultaneously. + barrier.wait() + resolved = agent.tool_by_names + bound_user_ids = {t.user_id for t in resolved.values() if hasattr(t, "user_id")} + assert len(bound_user_ids) == 1, f"cross-user leakage: {bound_user_ids}" + (uid,) = bound_user_ids + with snapshots_lock: + snapshots[uid] = {name for name, t in resolved.items() if hasattr(t, "user_id")} + return "ok" + + with patch.object(agent, "_run_agent", side_effect=fake_run): + with ThreadPoolExecutor(max_workers=2) as pool: + futures = [ + pool.submit(agent.run_sync, input_data={"input": "hi", "user_id": "u1"}), + pool.submit(agent.run_sync, input_data={"input": "hi", "user_id": "u2"}), + ] + for f in futures: + f.result(timeout=10) + + assert set(snapshots.keys()) == {"u1", "u2"} + for tool_names in snapshots.values(): + assert tool_names == {"remember_fact", "recall_facts"} + + +def test_concurrent_no_user_id_call_does_not_see_other_users_ltm_tools(llm, ltm): + """A concurrent no-user_id execute must not observe another call's + user-scoped tools — ContextVar isolation guarantees this without a lock.""" + import threading + from 
concurrent.futures import ThreadPoolExecutor + + agent = _make_agent(llm, ltm=ltm) + snapshots: dict[str, set] = {} + snapshots_lock = threading.Lock() + barrier = threading.Barrier(2, timeout=5) + + def fake_run(*args, **kwargs): + barrier.wait() + resolved = agent.tool_by_names + bound = {getattr(t, "user_id", None) for t in resolved.values() if hasattr(t, "user_id")} + with snapshots_lock: + key = next(iter(bound), "none") + snapshots[key] = bound + return "ok" + + with patch.object(agent, "_run_agent", side_effect=fake_run): + with ThreadPoolExecutor(max_workers=2) as pool: + futures = [ + pool.submit(agent.run_sync, input_data={"input": "hi", "user_id": "u1"}), + pool.submit(agent.run_sync, input_data={"input": "hi"}), + ] + for f in futures: + f.result(timeout=10) + + assert snapshots.get("u1") == {"u1"} + assert snapshots.get("none", set()) == set()