diff --git a/examples/encrypted-memory-mdb.py b/examples/encrypted-memory-mdb.py new file mode 100644 index 0000000..43d6740 --- /dev/null +++ b/examples/encrypted-memory-mdb.py @@ -0,0 +1,168 @@ +from memorizz.memory_provider.mongodb.provider import MongoDBConfig, MongoDBProvider +import os + +from bson import ObjectId +from pymongo import MongoClient +from pymongo.encryption import ClientEncryption, Algorithm +from bson.binary import Binary, STANDARD +from bson.codec_options import CodecOptions + +# fresh key every run: local_master_key_string = os.urandom(96) + +# --- Manage the Master Key persistently --- +# Define a path for the master key file +MASTER_KEY_FILE = "master_key.bin" + +# Check if the master key file exists +if os.path.exists(MASTER_KEY_FILE): + # Load the existing master key + with open(MASTER_KEY_FILE, "rb") as f: + local_master_key_string = f.read() + print("Loaded existing master key.") +else: + # Generate a new master key and save it to a file + local_master_key_string = os.urandom(96) + with open(MASTER_KEY_FILE, "wb") as f: + f.write(local_master_key_string) + print("Generated and saved new master key.") +# Keep API keys already exported in the environment; only fill in an empty placeholder when unset +os.environ.setdefault("VOYAGE_API_KEY", "") +os.environ.setdefault("OPENAI_API_KEY", "") + +# --- Configure KMS and Key Vault --- +# The master key is used to encrypt the data keys. For this local demo, +# we use a randomly generated 96-byte key. In a production environment, +# this key would be managed by a secure Key Management Service (KMS) +# like AWS KMS, Azure Key Vault, or GCP KMS. +# A separate MongoClient is used for the key vault for security best practices. +# Configure the local KMS provider +kms_providers = {"local": {"key": local_master_key_string}} + +# Define the namespace for the key vault collection, which will store the data keys. +key_vault_namespace = "encryption.__pymongoTestKeyVault" + +# Initialize a separate client for the key vault. This is a security best practice. 
+key_vault_client = MongoClient("mongodb://localhost:27017/?directConnection=true") + +# The MongoDBConfig object encapsulates all necessary settings: URI, database name, +# embedding provider details for vector search, and the critical encryption configuration. +# The `encryption_config` dictionary tells our provider: +# - Where to find the master key (kms_providers) +# - Where to store the data keys (key_vault_namespace) +# - Which fields to encrypt for specific collections (`field_mappings`) +# The algorithms are specified per field: `Random` encryption is used here because +# we are encrypting a complex BSON object (`function` and `type`), which is +# not queryable and therefore does not require `Deterministic` encryption. +mongodb_config = MongoDBConfig( + uri="mongodb://localhost:27017/?retryWrites=true&w=majority&directConnection=true", + db_name="testing_memorizz", + embedding_provider="voyageai", + embedding_config={ + "embedding_type": "text", + "model": "voyage-3.5", + "output_dimension": 1024, + "input_type": "query" + }, + encryption_config={ + "kms_providers": kms_providers, + "key_vault_namespace": key_vault_namespace, + "key_vault_client": key_vault_client, + "field_mappings": { + "toolbox": { + "encrypted_fields": { + "function": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + "type": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random + } + } + } + } +) + +# Create a memory provider +memory_provider = MongoDBProvider(mongodb_config) + +from functools import lru_cache +from yahooquery import Ticker +import time + +@lru_cache(maxsize=128) +def _fetch_price(symbol: str) -> float: + """ + Internal helper to fetch the latest market price via yahooquery. + Caching helps avoid repeated hits for the same symbol. 
+ """ + ticker = Ticker(symbol) + # This returns a dict keyed by symbol: + info = ticker.price or {} + # regularMarketPrice holds the current trading price + price = info.get(symbol.upper(), {}).get("regularMarketPrice") + if price is None: + raise ValueError(f"No price data for '{symbol}'") + return price + +def get_stock_price( + symbol: str, + currency: str = "USD", + retry: int = 3, + backoff: float = 0.5 +) -> str: + """ + Get the current stock price for a given symbol using yahooquery, + with simple retry/backoff to handle occasional rate-limits. + + Parameters + ---------- + symbol : str + Stock ticker, e.g. "AAPL" + currency : str, optional + Currency code (Currently informational only; yahooquery returns native) + retry : int, optional + Number of retries on failure (default: 3) + backoff : float, optional + Backoff factor in seconds between retries (default: 0.5s) + + Returns + ------- + str + e.g. "The current price of AAPL is 172.34 USD." + """ + symbol = symbol.upper() + last_err = None + for attempt in range(1, retry + 1): + try: + price = _fetch_price(symbol) + return f"The current price of {symbol} is {price:.2f} {currency.upper()}." 
+ except Exception as e: + last_err = e + # simple backoff + time.sleep(backoff * attempt) + # if we get here, all retries failed + raise RuntimeError(f"Failed to fetch price for '{symbol}' after {retry} attempts: {last_err}") + +import requests + +def get_weather(latitude, longitude): + response = requests.get(f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}&current=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m", timeout=10) + data = response.json() + return data['current']['temperature_2m'] + +from memorizz.long_term_memory.procedural.toolbox import Toolbox + +from memorizz.embeddings import configure_embeddings, get_embedding +configure_embeddings( + provider="voyageai", + config={ + "embedding_type": "text", + "model": "voyage-3.5", + "output_dimension": 1024, + "input_type": "query" + } +) + +toolbox = Toolbox( + memory_provider=memory_provider, +) +# Now the tools are registered in the memory provider within the toolbox +toolbox.register_tool(get_weather) +toolbox.register_tool(get_stock_price) +print(toolbox.list_tools()) \ No newline at end of file diff --git a/src/memorizz/memory_provider/mongodb/provider.py b/src/memorizz/memory_provider/mongodb/provider.py index c2262e0..d2b50fa 100644 --- a/src/memorizz/memory_provider/mongodb/provider.py +++ b/src/memorizz/memory_provider/mongodb/provider.py @@ -2,6 +2,11 @@ import logging from bson import ObjectId from pymongo import MongoClient +from pymongo.errors import PyMongoError +from pymongo.encryption import ClientEncryption, Algorithm +from bson.binary import Binary, STANDARD +from bson.codec_options import CodecOptions + from ..base import MemoryProvider from dataclasses import dataclass from ...enums.memory_type import MemoryType @@ -12,13 +17,15 @@ from pymongo.operations import SearchIndexModel from ...embeddings import get_embedding, get_embedding_dimensions +import uuid + logger = logging.getLogger(__name__) @dataclass class MongoDBConfig(): 
"""Configuration for the MongoDB provider.""" - def __init__(self, uri: str, db_name: str = "memorizz", lazy_vector_indexes: bool = False, embedding_provider = None, embedding_config: Dict[str, Any] = None): + def __init__(self, uri: str, db_name: str = "memorizz", lazy_vector_indexes: bool = False, embedding_provider = None, embedding_config: Dict[str, Any] = None, encryption_config: Dict[str, Any] = None): """ Initialize the MongoDB provider with configuration settings. @@ -40,12 +47,16 @@ def __init__(self, uri: str, db_name: str = "memorizz", lazy_vector_indexes: boo embedding_config : Dict[str, Any], optional Configuration for the embedding provider. Only used when embedding_provider is a string. Example: {"model": "text-embedding-3-small", "dimensions": 512} + encryption_config : Dict[str, Any], optional + Configuration for Client-Side Field Level Encryption (CSFLE) settings. """ self.uri = uri self.db_name = db_name self.lazy_vector_indexes = lazy_vector_indexes self.embedding_provider = embedding_provider self.embedding_config = embedding_config or {} + # The encryption configuration is stored, but no actions are performed. + self.encryption_config = encryption_config or {} class MongoDBProvider(MemoryProvider): @@ -54,19 +65,75 @@ class MongoDBProvider(MemoryProvider): def __init__(self, config: MongoDBConfig): """ Initialize the MongoDB provider with configuration settings. - + Parameters: ----------- config : MongoDBConfig - Configuration dictionary containing: - - 'uri': MongoDB URI - - 'db_name': Database name - - 'lazy_vector_indexes': Whether to defer vector index creation - - 'embedding_provider': Optional explicit embedding provider + Configuration object containing database, embedding, and encryption settings. 
""" self.config = config self.client = MongoClient(config.uri) self.db = self.client[config.db_name] + + # --- Initialize CSFLE with smart checks and dynamic keymap creation --- + self._client_encryption = None + self.keymap = {} # Initialize the keymap as an empty dictionary + + if self.config.encryption_config: + try: + # Extract key vault details from the config + kms_providers = self.config.encryption_config.get("kms_providers") + key_vault_namespace = self.config.encryption_config.get("key_vault_namespace") + key_vault_client = self.config.encryption_config.get("key_vault_client") + + """ + A separate keyvault client is a good idea for encryption because it enforces + a crucial security practice: the separation of data and keys. + """ + + if not all([kms_providers, key_vault_namespace, key_vault_client]): + raise ValueError("Incomplete encryption_config provided.") + + # Check if key vault collection exists before creating it + key_vault_db, key_vault_coll = key_vault_namespace.split(".", 1) + if key_vault_coll not in key_vault_client[key_vault_db].list_collection_names(): + print(f"Key vault collection '{key_vault_coll}' not found. 
Creating it now.") + key_vault_client[key_vault_db].create_collection(key_vault_coll) + else: + print(f"Key vault collection '{key_vault_coll}' already exists.") + + # Initialize the ClientEncryption object + self._client_encryption = ClientEncryption( + kms_providers=kms_providers, + key_vault_namespace=key_vault_namespace, + key_vault_client=key_vault_client, + codec_options=CodecOptions(uuid_representation=STANDARD), + ) + + # Loop through field mappings to get or create data keys + field_mappings = self.config.encryption_config.get("field_mappings", {}) + for collection_name, mapping in field_mappings.items(): + # Check if a data key with this alt name already exists + data_key = self._client_encryption.get_key_by_alt_name(collection_name) + if data_key is None: + # If not, create one under the collection name so the lookup above finds and reuses it on later runs + alt_name = collection_name + data_key_id = self._client_encryption.create_data_key("local", key_alt_names=[alt_name]) + print(f"Generated new data key with ID: {data_key_id} for collection '{collection_name}'") + self.keymap[collection_name] = alt_name + else: + print(f"Using existing data key for collection '{collection_name}'") + # You would need to retrieve the alt name from the data_key document + # as it may not match the collection_name. For simplicity, we assume + # it does or use the first one available. 
+ self.keymap[collection_name] = data_key['keyAltNames'][0] + + except Exception as e: + logger.error(f"Error initializing CSFLE: {e}") + self._client_encryption = None + self.keymap = {} + + # --- EXISTING CODE --- self.persona_collection = self.db[MemoryType.PERSONAS.value] self.toolbox_collection = self.db[MemoryType.TOOLBOX.value] self.short_term_memory_collection = self.db[MemoryType.SHORT_TERM_MEMORY.value] @@ -79,7 +146,7 @@ def __init__(self, config: MongoDBConfig): # Track which vector indexes have been created self._vector_indexes_created = set() - + # Process embedding provider configuration self._embedding_provider = self._setup_embedding_provider(config) @@ -96,6 +163,7 @@ def __init__(self, config: MongoDBConfig): # Set lazy mode if immediate creation fails self.config.lazy_vector_indexes = True + def _setup_embedding_provider(self, config: MongoDBConfig): """ Setup the embedding provider based on configuration. @@ -248,24 +316,55 @@ def _create_vector_indexes_for_memory_stores(self) -> None: memory_store=memory_store_present, ) - def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: + def _is_encrypted_field(self, collection_name: str, field_name: str) -> Optional[str]: """ - Store data in MongoDB using only _id field as primary key. + Checks if a field should be encrypted based on the encryption config. + + Parameters: + ----------- + collection_name : str + The name of the collection being checked. + field_name : str + The name of the field to check for encryption. + + Returns: + -------- + Optional[str] + The encryption algorithm as a string if the field should be encrypted, otherwise None. 
+ """ + if not self.config.encryption_config: + return None + field_mappings = self.config.encryption_config.get("field_mappings", {}) + collection_config = field_mappings.get(collection_name) + + if not collection_config: + return None + + encrypted_fields = collection_config.get("encrypted_fields", {}) + + # This now returns the algorithm type from the config + return encrypted_fields.get(field_name) + + def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: + """ + Store data in MongoDB, encrypting sensitive fields if CSFLE is configured. + Parameters: ----------- data : Dict[str, Any] The document to be stored. memory_store_type : MemoryType The type of memory store (e.g., "persona", "toolbox", etc.) - + Returns: -------- str The ID of the inserted/updated document (MongoDB _id). """ - # Get the appropriate collection based on memory type collection = None + collection_name = memory_store_type.value + if memory_store_type == MemoryType.PERSONAS: collection = self.persona_collection elif memory_store_type == MemoryType.TOOLBOX: @@ -286,28 +385,41 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: if collection is None: raise ValueError(f"Invalid memory store type: {memory_store_type}") - # Clean data by removing custom ID fields - only use MongoDB _id - # Note: conversation_id is preserved for CONVERSATION_MEMORY as it serves a functional purpose data_copy = data.copy() - # Remove custom ID fields since we only want to use _id + # Clean custom ID fields custom_id_fields = [ - "persona_id", "tool_id", "workflow_id", "short_term_memory_id", + "persona_id", "tool_id", "workflow_id", "short_term_memory_id", "agent_id" ] - - # Don't remove conversation_id for conversation memory if memory_store_type != MemoryType.CONVERSATION_MEMORY: custom_id_fields.append("conversation_id") - - # Don't remove long_term_memory_id for long-term memory as it's needed for knowledge linking if memory_store_type != 
MemoryType.LONG_TERM_MEMORY: custom_id_fields.append("long_term_memory_id") for field in custom_id_fields: data_copy.pop(field, None) - - # If document has MongoDB _id, update it + + # Encryption Logic: check if CSFLE is configured and a keymap exists + if self._client_encryption and self.keymap: + for field_name, value in list(data_copy.items()): + algorithm = self._is_encrypted_field(collection_name, field_name) + + # If encryption is required for this field... + if algorithm: + key_alt_name = self.keymap.get(collection_name) + if not key_alt_name: + raise ValueError(f"Encryption key not found for collection: {collection_name}") + + # Explicitly encrypt the field's value + encrypted_value = self._client_encryption.encrypt( + value, + Algorithm(algorithm), + key_alt_name=key_alt_name + ) + data_copy[field_name] = encrypted_value + + # If document has MongoDB _id, update it (upsert) if "_id" in data_copy: result = collection.update_one( {"_id": data_copy["_id"]}, @@ -320,9 +432,10 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: result = collection.insert_one(data_copy) return str(result.inserted_id) - def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType, limit: int = 1, include_embedding: bool = False) -> Optional[Dict[str, Any]]: + + def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType, limit: int = 1, include_embedding: bool = False) -> Optional[List[Dict[str, Any]]]: """ - Retrieve a document from MongoDB. + Retrieve documents from MongoDB, decrypting fields if CSFLE is configured. Parameters: ----------- @@ -335,28 +448,58 @@ def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType Returns: -------- - Optional[Dict[str, Any]] - The retrieved document, or None if not found. + Optional[List[Dict[str, Any]]] + A list of retrieved and decrypted documents, or None if not found. 
""" # Define projection to exclude embeddings by default projection = {} if include_embedding else {"embedding": 0} + documents = [] + + # Dispatch logic to call the correct specialized retrieval method if memory_store_type == MemoryType.PERSONAS: - return self.retrieve_persona_by_query(query, limit=limit) + documents = self.retrieve_persona_by_query(query, limit=limit) elif memory_store_type == MemoryType.TOOLBOX: - return self.retrieve_toolbox_item(query, limit) + documents = self.retrieve_toolbox_item(query, limit=limit) elif memory_store_type == MemoryType.WORKFLOW_MEMORY: - return self.retrieve_workflow_by_query(query, limit) + documents = self.retrieve_workflow_by_query(query, limit=limit) elif memory_store_type == MemoryType.SHORT_TERM_MEMORY: - return self.short_term_memory_collection.find(query, projection).limit(limit) + documents = list(self.short_term_memory_collection.find(query, projection).limit(limit)) elif memory_store_type == MemoryType.LONG_TERM_MEMORY: - return self.long_term_memory_collection.find(query, projection).limit(limit) + documents = list(self.long_term_memory_collection.find(query, projection).limit(limit)) elif memory_store_type == MemoryType.CONVERSATION_MEMORY: - return self.conversation_memory_collection.find(query, projection).limit(limit) + documents = list(self.conversation_memory_collection.find(query, projection).limit(limit)) elif memory_store_type == MemoryType.SUMMARIES: - return self.retrieve_summaries_by_query(query, limit) - + documents = self.retrieve_summaries_by_query(query, limit=limit) + + # Return early if no documents are found by the dispatcher + if not documents: + return None + + # --- New Decryption Logic --- + # Only attempt decryption if CSFLE is enabled + if self._client_encryption: + decrypted_documents = [] + for doc in documents: + decrypted_doc = doc.copy() + for field_name, value in doc.items(): + # Only BSON Binary subtype 6 is CSFLE ciphertext; decrypt() raises TypeError for any other subtype + if isinstance(value, Binary) and value.subtype == 6: + try: + # Attempt 
to decrypt the value + decrypted_value = self._client_encryption.decrypt(value) + decrypted_doc[field_name] = decrypted_value + except PyMongoError as e: + logger.error(f"Failed to decrypt field '{field_name}' in document {doc.get('_id', '')}: {e}") + # Keep the encrypted value if decryption fails to avoid crashing + decrypted_doc[field_name] = value + decrypted_documents.append(decrypted_doc) + return decrypted_documents + + # If CSFLE is not configured, just return the raw documents + return documents + def retrieve_by_id(self, id: str, memory_store_type: MemoryType) -> Optional[Dict[str, Any]]: """ Retrieve a document from MongoDB by _id. @@ -828,7 +971,7 @@ def delete_all(self, memory_store_type: MemoryType) -> bool: def list_all(self, memory_store_type: MemoryType, include_embedding: bool = False) -> List[Dict[str, Any]]: """ - List all documents within a memory store type in MongoDB. + List all documents within a memory store, decrypting fields if CSFLE is configured. Parameters: ----------- @@ -845,26 +988,53 @@ def list_all(self, memory_store_type: MemoryType, include_embedding: bool = Fals # Define projection to exclude embeddings by default projection = {} if include_embedding else {"embedding": 0} + collection = None if memory_store_type == MemoryType.PERSONAS: - return list(self.persona_collection.find({}, projection)) + collection = self.persona_collection elif memory_store_type == MemoryType.TOOLBOX: - return list(self.toolbox_collection.find({}, projection)) + collection = self.toolbox_collection elif memory_store_type == MemoryType.SHORT_TERM_MEMORY: - return list(self.short_term_memory_collection.find({}, projection)) + collection = self.short_term_memory_collection elif memory_store_type == MemoryType.LONG_TERM_MEMORY: - return list(self.long_term_memory_collection.find({}, projection)) + collection = self.long_term_memory_collection elif memory_store_type == MemoryType.CONVERSATION_MEMORY: - return 
list(self.conversation_memory_collection.find({}, projection)) + collection = self.conversation_memory_collection elif memory_store_type == MemoryType.WORKFLOW_MEMORY: - return list(self.workflow_memory_collection.find({}, projection)) + collection = self.workflow_memory_collection elif memory_store_type == MemoryType.SHARED_MEMORY: - return list(self.shared_memory_collection.find({}, projection)) + collection = self.shared_memory_collection elif memory_store_type == MemoryType.SUMMARIES: - return list(self.summaries_collection.find({}, projection)) + collection = self.summaries_collection else: logger.warning(f"Unsupported memory store type for list_all: {memory_store_type}") return [] + # Retrieve all documents from the collection + documents = list(collection.find({}, projection)) + + # --- Decryption Logic --- + # Only attempt to decrypt if CSFLE is configured and initialized + if self._client_encryption: + decrypted_documents = [] + for doc in documents: + decrypted_doc = doc.copy() + for field_name, value in doc.items(): + # Only BSON Binary subtype 6 is CSFLE ciphertext; decrypt() raises TypeError for any other subtype + if isinstance(value, Binary) and value.subtype == 6: + try: + # Attempt to decrypt the field + decrypted_value = self._client_encryption.decrypt(value) + decrypted_doc[field_name] = decrypted_value + except PyMongoError as e: + logger.error(f"Failed to decrypt field '{field_name}' in document {doc.get('_id', '')}: {e}") + # Keep the encrypted value if decryption fails + decrypted_doc[field_name] = value + decrypted_documents.append(decrypted_doc) + return decrypted_documents + + # If CSFLE is not configured, just return the raw documents + return documents + def update_by_id(self, id: str, data: Dict[str, Any], memory_store_type: MemoryType) -> bool: """ Update a document in a memory store type in MongoDB by _id. @@ -1427,7 +1597,7 @@ def delete_memagent(self, agent_id: str, cascade: bool = False) -> bool: True if deletion was successful, False otherwise. 
""" if cascade: - # Retrieve the memagent + # Retrieve the memagent memagent = self.retrieve_memagent(agent_id) if memagent is None: @@ -1481,7 +1651,6 @@ def _delete_memory_units_by_memory_id(self, memory_id: str, memory_type: MemoryT self.toolbox_collection.delete_many({"memory_id": memory_id}) elif memory_type == MemoryType.MEMAGENT: self.memagent_collection.delete_many({"memory_id": memory_id}) - def _setup_vector_search_index(self, collection, index_name="vector_index", memory_store: bool = False): """