From 5c0fec15d8e7309bf4b4ee6fc7f7b4e0ddd7c2b6 Mon Sep 17 00:00:00 2001 From: ranfysvalle02 Date: Wed, 27 Aug 2025 14:56:31 -0400 Subject: [PATCH 1/6] CSFLE support; babysteps --- examples/encrypted-memory-mdb.py | 156 +++++++++++++++++ .../memory_provider/mongodb/provider.py | 157 +++++++++++++----- 2 files changed, 274 insertions(+), 39 deletions(-) create mode 100644 examples/encrypted-memory-mdb.py diff --git a/examples/encrypted-memory-mdb.py b/examples/encrypted-memory-mdb.py new file mode 100644 index 0000000..7951b8f --- /dev/null +++ b/examples/encrypted-memory-mdb.py @@ -0,0 +1,156 @@ +from memorizz.memory_provider.mongodb.provider import MongoDBConfig, MongoDBProvider +import os + +from bson import ObjectId +from pymongo import MongoClient +from pymongo.encryption import ClientEncryption, Algorithm +from bson.binary import Binary, STANDARD +from bson.codec_options import CodecOptions + +local_master_key_string = os.urandom(96) +os.environ["VOYAGE_API_KEY"] = "" + +# --- Step 2: Configure KMS and Key Vault --- +# The master key is used to encrypt the data keys. For this local demo, +# we use a randomly generated 96-byte key. In a production environment, +# this key would be managed by a secure Key Management Service (KMS) +# like AWS KMS, Azure Key Vault, or GCP KMS. +# A separate MongoClient is used for the key vault for security best practices. +# Configure the local KMS provider +kms_providers = {"local": {"key": local_master_key_string}} + +# Define the namespace for the key vault collection, which will store the data keys. +key_vault_namespace = "encryption.__pymongoTestKeyVault" + +# Initialize a separate client for the key vault. This is a security best practice. +key_vault_client = MongoClient("mongodb://localhost:27017/?directConnection=true") + +# The MongoDBConfig object encapsulates all necessary settings: URI, database name, +# embedding provider details for vector search, and the critical encryption configuration. +# The `encryption_config` dictionary tells our provider: +# - Where to find the master key (kms_providers) +# - Where to store the data keys (key_vault_namespace) +# - Which fields to encrypt for specific collections (`field_mappings`) +# The algorithms are specified per field: `Random` encryption is used here because +# we are encrypting a complex BSON object (`function` and `type`), which is +# not queryable and therefore does not require `Deterministic` encryption. +mongodb_config = MongoDBConfig( + uri="mongodb://localhost:27017/?retryWrites=true&w=majority&directConnection=true", + db_name="testing_memorizz", + embedding_provider="voyageai", + embedding_config={ + "embedding_type": "text", + "model": "voyage-3.5", + "output_dimension": 1024, + "input_type": "query" + }, + encryption_config={ + "kms_providers": kms_providers, + "key_vault_namespace": key_vault_namespace, + "key_vault_client": key_vault_client, + "field_mappings": { + "toolbox": { + "encrypted_fields": { + "function": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + "type": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random + } + }, + "personas": { + "encrypted_fields": { + "background": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random + } + } + } + } +) + +# Create a memory provider +memory_provider = MongoDBProvider(mongodb_config) + +from functools import lru_cache +from yahooquery import Ticker +import time + +@lru_cache(maxsize=128) +def _fetch_price(symbol: str) -> float: + """ + Internal helper to fetch the latest market price via yahooquery. + Caching helps avoid repeated hits for the same symbol. + """ + ticker = Ticker(symbol) + # This returns a dict keyed by symbol: + info = ticker.price or {} + # regularMarketPrice holds the current trading price + price = info.get(symbol.upper(), {}).get("regularMarketPrice") + if price is None: + raise ValueError(f"No price data for '{symbol}'") + return price + +def get_stock_price( + symbol: str, + currency: str = "USD", + retry: int = 3, + backoff: float = 0.5 +) -> str: + """ + Get the current stock price for a given symbol using yahooquery, + with simple retry/backoff to handle occasional rate-limits. + + Parameters + ---------- + symbol : str + Stock ticker, e.g. "AAPL" + currency : str, optional + Currency code (Currently informational only; yahooquery returns native) + retry : int, optional + Number of retries on failure (default: 3) + backoff : float, optional + Backoff factor in seconds between retries (default: 0.5s) + + Returns + ------- + str + e.g. "The current price of AAPL is 172.34 USD." + """ + symbol = symbol.upper() + last_err = None + for attempt in range(1, retry + 1): + try: + price = _fetch_price(symbol) + return f"The current price of {symbol} is {price:.2f} {currency.upper()}." + except Exception as e: + last_err = e + # simple backoff + time.sleep(backoff * attempt) + # if we get here, all retries failed + raise RuntimeError(f"Failed to fetch price for '{symbol}' after {retry} attempts: {last_err}") + +import requests + +def get_weather(latitude, longitude): + response = requests.get(f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m") + data = response.json() + return data['current']['temperature_2m'] + +from memorizz.long_term_memory.procedural.toolbox import Toolbox + +from memorizz.embeddings import configure_embeddings, get_embedding +configure_embeddings( + provider="voyageai", + config={ + "embedding_type": "text", + "model": "voyage-3.5", + "output_dimension": 1024, + "input_type": "query" + } +) + +os.environ["OPENAI_API_KEY"] = "" + +toolbox = Toolbox( + memory_provider=memory_provider, +) +# Now the tools are registered in the memory provider within the toolbox +toolbox.register_tool(get_weather) +toolbox.register_tool(get_stock_price) +print(toolbox.list_tools()) \ No newline at end of file diff --git a/src/memorizz/memory_provider/mongodb/provider.py b/src/memorizz/memory_provider/mongodb/provider.py index c2262e0..7e577db 100644 --- a/src/memorizz/memory_provider/mongodb/provider.py +++ b/src/memorizz/memory_provider/mongodb/provider.py @@ -2,6 +2,10 @@ import logging from bson import ObjectId from pymongo import MongoClient +from pymongo.encryption import ClientEncryption, Algorithm +from bson.binary import Binary, STANDARD +from bson.codec_options import CodecOptions + from ..base import MemoryProvider from dataclasses import dataclass from ...enums.memory_type import MemoryType @@ -12,13 +16,15 @@ from pymongo.operations import SearchIndexModel from ...embeddings import get_embedding, get_embedding_dimensions +import uuid + logger = logging.getLogger(__name__) @dataclass class MongoDBConfig(): """Configuration for the MongoDB provider.""" - def __init__(self, uri: str, db_name: str = "memorizz", lazy_vector_indexes: bool = False, embedding_provider = None, embedding_config: Dict[str, Any] = None): + def __init__(self, uri: str, db_name: str = "memorizz", lazy_vector_indexes: bool = False, embedding_provider = None, embedding_config: Dict[str, Any] = None, encryption_config: Dict[str, Any] = None): """ Initialize the MongoDB provider with configuration settings. @@ -33,42 +39,95 @@ def __init__(self, uri: str, db_name: str = "memorizz", lazy_vector_indexes: boo If False, vector indexes are created immediately during initialization (requires embedding configuration). Default: False (maintains backward compatibility) embedding_provider : str or EmbeddingManager, optional - Embedding provider to use. Can be: - - EmbeddingManager instance (explicit injection) - - String provider name ("openai", "ollama", "voyageai") - - None (uses global embedding configuration) + Embedding provider to use. embedding_config : Dict[str, Any], optional - Configuration for the embedding provider. Only used when embedding_provider is a string. - Example: {"model": "text-embedding-3-small", "dimensions": 512} + Configuration for the embedding provider. + encryption_config : Dict[str, Any], optional + Configuration for encryption settings. """ self.uri = uri self.db_name = db_name self.lazy_vector_indexes = lazy_vector_indexes self.embedding_provider = embedding_provider self.embedding_config = embedding_config or {} + # The encryption configuration is stored, but no actions are performed. + self.encryption_config = encryption_config or {} class MongoDBProvider(MemoryProvider): """MongoDB implementation of the MemoryProvider interface.""" - def __init__(self, config: MongoDBConfig): + def __init__(self, config): """ Initialize the MongoDB provider with configuration settings. - + Parameters: ----------- config : MongoDBConfig - Configuration dictionary containing: - - 'uri': MongoDB URI - - 'db_name': Database name - - 'lazy_vector_indexes': Whether to defer vector index creation - - 'embedding_provider': Optional explicit embedding provider + Configuration object containing database, embedding, and encryption settings. """ self.config = config self.client = MongoClient(config.uri) self.db = self.client[config.db_name] + + # --- NEW: Initialize CSFLE with smart checks and dynamic keymap creation --- + self._client_encryption = None + self.keymap = {} # Initialize the keymap as an empty dictionary + + if self.config.encryption_config: + try: + # Extract key vault details from the config + kms_providers = self.config.encryption_config.get("kms_providers") + key_vault_namespace = self.config.encryption_config.get("key_vault_namespace") + key_vault_client = self.config.encryption_config.get("key_vault_client") + + if not all([kms_providers, key_vault_namespace, key_vault_client]): + raise ValueError("Incomplete encryption_config provided.") + + # Check if key vault collection exists before creating it + key_vault_db, key_vault_coll = key_vault_namespace.split(".", 1) + if key_vault_coll not in key_vault_client[key_vault_db].list_collection_names(): + print(f"Key vault collection '{key_vault_coll}' not found. Creating it now.") + key_vault_client[key_vault_db].create_collection(key_vault_coll) + else: + print(f"Key vault collection '{key_vault_coll}' already exists.") + + # Initialize the ClientEncryption object + self._client_encryption = ClientEncryption( + kms_providers=kms_providers, + key_vault_namespace=key_vault_namespace, + key_vault_client=key_vault_client, + codec_options=CodecOptions(uuid_representation=STANDARD), + ) + + # Loop through field mappings to get or create data keys + field_mappings = self.config.encryption_config.get("field_mappings", {}) + for collection_name, mapping in field_mappings.items(): + # Check if a data key with this alt name already exists + data_key = self._client_encryption.get_key_by_alt_name(collection_name) + if data_key is None: + # If not, create a new one using a UUID as the alt name for uniqueness + alt_name = str(uuid.uuid4()) + data_key_id = self._client_encryption.create_data_key("local", key_alt_names=[alt_name]) + print(f"Generated new data key with ID: {data_key_id} for collection '{collection_name}'") + self.keymap[collection_name] = alt_name + else: + print(f"Using existing data key for collection '{collection_name}'") + # You would need to retrieve the alt name from the data_key document + # as it may not match the collection_name. For simplicity, we assume + # it does or use the first one available. + self.keymap[collection_name] = data_key['keyAltNames'][0] + + except Exception as e: + logger.error(f"Error initializing CSFLE: {e}") + self._client_encryption = None + self.keymap = {} + + # --- EXISTING CODE --- self.persona_collection = self.db[MemoryType.PERSONAS.value] self.toolbox_collection = self.db[MemoryType.TOOLBOX.value] + print("TOOLBOX"+MemoryType.TOOLBOX.value) + print("&^^^^^^^^") self.short_term_memory_collection = self.db[MemoryType.SHORT_TERM_MEMORY.value] self.long_term_memory_collection = self.db[MemoryType.LONG_TERM_MEMORY.value] self.conversation_memory_collection = self.db[MemoryType.CONVERSATION_MEMORY.value] @@ -79,7 +138,7 @@ def __init__(self, config: MongoDBConfig): # Track which vector indexes have been created self._vector_indexes_created = set() - + # Process embedding provider configuration self._embedding_provider = self._setup_embedding_provider(config) @@ -247,25 +306,32 @@ def _create_vector_indexes_for_memory_stores(self) -> None: index_name="vector_index", memory_store=memory_store_present, ) - - def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: + def _is_encrypted_field(self, collection_name: str, field_name: str) -> bool: + """ + Checks if a field should be encrypted based on the encryption config. """ - Store data in MongoDB using only _id field as primary key. - Parameters: - ----------- - data : Dict[str, Any] - The document to be stored. - memory_store_type : MemoryType - The type of memory store (e.g., "persona", "toolbox", etc.) + if not self.config.encryption_config: + return None - Returns: - -------- - str - The ID of the inserted/updated document (MongoDB _id). + field_mappings = self.config.encryption_config.get("field_mappings", {}) + collection_config = field_mappings.get(collection_name) + + if not collection_config: + return None + + encrypted_fields = collection_config.get("encrypted_fields", {}) + + # This now returns the algorithm type from the config + return encrypted_fields.get(field_name) + + def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: + """ + Store data in MongoDB, encrypting sensitive fields. """ - # Get the appropriate collection based on memory type collection = None + collection_name = memory_store_type.value + if memory_store_type == MemoryType.PERSONAS: collection = self.persona_collection elif memory_store_type == MemoryType.TOOLBOX: @@ -286,28 +352,41 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: if collection is None: raise ValueError(f"Invalid memory store type: {memory_store_type}") - # Clean data by removing custom ID fields - only use MongoDB _id - # Note: conversation_id is preserved for CONVERSATION_MEMORY as it serves a functional purpose data_copy = data.copy() - # Remove custom ID fields since we only want to use _id + # Clean custom ID fields custom_id_fields = [ "persona_id", "tool_id", "workflow_id", "short_term_memory_id", "agent_id" ] - - # Don't remove conversation_id for conversation memory if memory_store_type != MemoryType.CONVERSATION_MEMORY: custom_id_fields.append("conversation_id") - - # Don't remove long_term_memory_id for long-term memory as it's needed for knowledge linking if memory_store_type != MemoryType.LONG_TERM_MEMORY: custom_id_fields.append("long_term_memory_id") for field in custom_id_fields: data_copy.pop(field, None) - - # If document has MongoDB _id, update it + + # Encryption Logic: check if CSFLE is configured + if self._client_encryption and self.keymap: + for field_name, value in list(data_copy.items()): + algorithm = self._is_encrypted_field(collection_name, field_name) + + # If encryption is required for this field... + if algorithm: + key_alt_name = self.keymap.get(collection_name) + if not key_alt_name: + raise ValueError(f"Encryption key not found for collection: {collection_name}") + + # Explicitly encrypt the field's value + encrypted_value = self._client_encryption.encrypt( + value, + algorithm, + key_alt_name=key_alt_name + ) + data_copy[field_name] = encrypted_value + + # Update or insert logic if "_id" in data_copy: result = collection.update_one( {"_id": data_copy["_id"]}, @@ -316,10 +395,10 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: ) return str(data_copy["_id"]) else: - # For new documents, let MongoDB generate _id automatically result = collection.insert_one(data_copy) return str(result.inserted_id) + def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType, limit: int = 1, include_embedding: bool = False) -> Optional[Dict[str, Any]]: """ Retrieve a document from MongoDB. From d4cc8ab6b34bf3895df7f9703423298ed43799f7 Mon Sep 17 00:00:00 2001 From: ranfysvalle02 Date: Wed, 27 Aug 2025 15:49:51 -0400 Subject: [PATCH 2/6] retrieval decryption when possible --- .../memory_provider/mongodb/provider.py | 87 +++++++++++++++---- 1 file changed, 71 insertions(+), 16 deletions(-) diff --git a/src/memorizz/memory_provider/mongodb/provider.py b/src/memorizz/memory_provider/mongodb/provider.py index 7e577db..78154a8 100644 --- a/src/memorizz/memory_provider/mongodb/provider.py +++ b/src/memorizz/memory_provider/mongodb/provider.py @@ -421,20 +421,48 @@ def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType # Define projection to exclude embeddings by default projection = {} if include_embedding else {"embedding": 0} + documents = [] + + # Dispatch logic to call the correct specialized retrieval method if memory_store_type == MemoryType.PERSONAS: - return self.retrieve_persona_by_query(query, limit=limit) + documents = self.retrieve_persona_by_query(query, limit=limit) elif memory_store_type == MemoryType.TOOLBOX: - return self.retrieve_toolbox_item(query, limit) + documents = self.retrieve_toolbox_item(query, limit=limit) elif memory_store_type == MemoryType.WORKFLOW_MEMORY: - return self.retrieve_workflow_by_query(query, limit) + documents = self.retrieve_workflow_by_query(query, limit=limit) elif memory_store_type == MemoryType.SHORT_TERM_MEMORY: - return self.short_term_memory_collection.find(query, projection).limit(limit) + documents = list(self.short_term_memory_collection.find(query, projection).limit(limit)) elif memory_store_type == MemoryType.LONG_TERM_MEMORY: - return self.long_term_memory_collection.find(query, projection).limit(limit) + documents = list(self.long_term_memory_collection.find(query, projection).limit(limit)) elif memory_store_type == MemoryType.CONVERSATION_MEMORY: - return self.conversation_memory_collection.find(query, projection).limit(limit) + documents = list(self.conversation_memory_collection.find(query, projection).limit(limit)) elif memory_store_type == MemoryType.SUMMARIES: - return self.retrieve_summaries_by_query(query, limit) + documents = self.retrieve_summaries_by_query(query, limit=limit) + + # Return early if no documents are found by the dispatcher + if not documents: + return None + + # --- New Decryption Logic --- + if self._client_encryption: + decrypted_documents = [] + for doc in documents: + decrypted_doc = doc.copy() + for field_name, value in doc.items(): + # Check if the value is a Binary object (which indicates encryption) + if isinstance(value, Binary): + try: + decrypted_value = self._client_encryption.decrypt(value) + decrypted_doc[field_name] = decrypted_value + except PyMongoError as e: + logger.error(f"Failed to decrypt field '{field_name}' in document {doc.get('_id', '')}: {e}") + # Keep the encrypted value if decryption fails + decrypted_doc[field_name] = value + decrypted_documents.append(decrypted_doc) + return decrypted_documents + + # If CSFLE is not configured, just return the raw documents + return documents def retrieve_by_id(self, id: str, memory_store_type: MemoryType) -> Optional[Dict[str, Any]]: """ @@ -907,7 +935,7 @@ def delete_all(self, memory_store_type: MemoryType) -> bool: def list_all(self, memory_store_type: MemoryType, include_embedding: bool = False) -> List[Dict[str, Any]]: """ - List all documents within a memory store type in MongoDB. + List all documents within a memory store type in MongoDB, decrypting encrypted fields. Parameters: ----------- @@ -924,26 +952,53 @@ def list_all(self, memory_store_type: MemoryType, include_embedding: bool = Fals # Define projection to exclude embeddings by default projection = {} if include_embedding else {"embedding": 0} + collection = None if memory_store_type == MemoryType.PERSONAS: - return list(self.persona_collection.find({}, projection)) + collection = self.persona_collection elif memory_store_type == MemoryType.TOOLBOX: - return list(self.toolbox_collection.find({}, projection)) + collection = self.toolbox_collection elif memory_store_type == MemoryType.SHORT_TERM_MEMORY: - return list(self.short_term_memory_collection.find({}, projection)) + collection = self.short_term_memory_collection elif memory_store_type == MemoryType.LONG_TERM_MEMORY: - return list(self.long_term_memory_collection.find({}, projection)) + collection = self.long_term_memory_collection elif memory_store_type == MemoryType.CONVERSATION_MEMORY: - return list(self.conversation_memory_collection.find({}, projection)) + collection = self.conversation_memory_collection elif memory_store_type == MemoryType.WORKFLOW_MEMORY: - return list(self.workflow_memory_collection.find({}, projection)) + collection = self.workflow_memory_collection elif memory_store_type == MemoryType.SHARED_MEMORY: - return list(self.shared_memory_collection.find({}, projection)) + collection = self.shared_memory_collection elif memory_store_type == MemoryType.SUMMARIES: - return list(self.summaries_collection.find({}, projection)) + collection = self.summaries_collection else: logger.warning(f"Unsupported memory store type for list_all: {memory_store_type}") return [] + # Retrieve all documents from the collection + documents = list(collection.find({}, projection)) + + # --- NEW: Decryption Logic --- + # Only attempt to decrypt if CSFLE is configured and initialized + if self._client_encryption: + decrypted_documents = [] + for doc in documents: + decrypted_doc = doc.copy() + for field_name, value in doc.items(): + # Check if the value is a Binary object (which indicates encryption) + if isinstance(value, Binary): + try: + # Attempt to decrypt the field + decrypted_value = self._client_encryption.decrypt(value) + decrypted_doc[field_name] = decrypted_value + except PyMongoError as e: + logger.error(f"Failed to decrypt field '{field_name}' in document {doc.get('_id', '')}: {e}") + # Keep the encrypted value if decryption fails + decrypted_doc[field_name] = value + decrypted_documents.append(decrypted_doc) + return decrypted_documents + + # If CSFLE is not configured, just return the raw documents as they are + return documents + def update_by_id(self, id: str, data: Dict[str, Any], memory_store_type: MemoryType) -> bool: """ Update a document in a memory store type in MongoDB by _id. From fcab650522a718c8fc1e153415cf6a37ab72d387 Mon Sep 17 00:00:00 2001 From: ranfysvalle02 Date: Wed, 27 Aug 2025 16:04:19 -0400 Subject: [PATCH 3/6] add a small comment here --- src/memorizz/memory_provider/mongodb/provider.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/memorizz/memory_provider/mongodb/provider.py b/src/memorizz/memory_provider/mongodb/provider.py index 78154a8..0166b44 100644 --- a/src/memorizz/memory_provider/mongodb/provider.py +++ b/src/memorizz/memory_provider/mongodb/provider.py @@ -79,6 +79,12 @@ def __init__(self, config): # Extract key vault details from the config kms_providers = self.config.encryption_config.get("kms_providers") key_vault_namespace = self.config.encryption_config.get("key_vault_namespace") + + """ + A separate keyvault client is a good idea for encryption because it enforces + a crucial security practice: the separation of data and keys. + """ + key_vault_client = self.config.encryption_config.get("key_vault_client") if not all([kms_providers, key_vault_namespace, key_vault_client]): From 6c855267ab40f7ad6c52152dcdfaa4367af71b11 Mon Sep 17 00:00:00 2001 From: ranfysvalle02 Date: Wed, 27 Aug 2025 16:23:54 -0400 Subject: [PATCH 4/6] master_key persistance for examples demo to run more than once! --- examples/encrypted-memory-mdb.py | 30 +++++++++++++------ .../memory_provider/mongodb/provider.py | 16 +++++----- 2 files changed, 28 insertions(+), 18 deletions(-) diff --git a/examples/encrypted-memory-mdb.py b/examples/encrypted-memory-mdb.py index 7951b8f..43d6740 100644 --- a/examples/encrypted-memory-mdb.py +++ b/examples/encrypted-memory-mdb.py @@ -7,10 +7,29 @@ from bson.binary import Binary, STANDARD from bson.codec_options import CodecOptions -local_master_key_string = os.urandom(96) +# fresh key every run: local_master_key_string = os.urandom(96) + +# --- Manage the Master Key persistently --- +# Define a path for the master key file +MASTER_KEY_FILE = "master_key.bin" + +# Check if the master key file exists +if os.path.exists(MASTER_KEY_FILE): + # Load the existing master key + with open(MASTER_KEY_FILE, "rb") as f: + local_master_key_string = f.read() + print("Loaded existing master key.") +else: + # Generate a new master key and save it to a file + local_master_key_string = os.urandom(96) + with open(MASTER_KEY_FILE, "wb") as f: + f.write(local_master_key_string) + print("Generated and saved new master key.") + os.environ["VOYAGE_API_KEY"] = "" +os.environ["OPENAI_API_KEY"] = "" -# --- Step 2: Configure KMS and Key Vault --- +# --- Configure KMS and Key Vault --- # The master key is used to encrypt the data keys. For this local demo, # we use a randomly generated 96-byte key. In a production environment, # this key would be managed by a secure Key Management Service (KMS) @@ -54,11 +73,6 @@ "function": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, "type": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random } - }, - "personas": { - "encrypted_fields": { - "background": Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random - } } } } @@ -145,8 +159,6 @@ def get_weather(latitude, longitude): } ) -os.environ["OPENAI_API_KEY"] = "" - toolbox = Toolbox( memory_provider=memory_provider, ) diff --git a/src/memorizz/memory_provider/mongodb/provider.py b/src/memorizz/memory_provider/mongodb/provider.py index 0166b44..3f6eac0 100644 --- a/src/memorizz/memory_provider/mongodb/provider.py +++ b/src/memorizz/memory_provider/mongodb/provider.py @@ -2,6 +2,7 @@ import logging from bson import ObjectId from pymongo import MongoClient +from pymongo.errors import PyMongoError from pymongo.encryption import ClientEncryption, Algorithm from bson.binary import Binary, STANDARD from bson.codec_options import CodecOptions @@ -79,14 +80,13 @@ def __init__(self, config): # Extract key vault details from the config kms_providers = self.config.encryption_config.get("kms_providers") key_vault_namespace = self.config.encryption_config.get("key_vault_namespace") + key_vault_client = self.config.encryption_config.get("key_vault_client") """ - A separate keyvault client is a good idea for encryption because it enforces - a crucial security practice: the separation of data and keys. + A separate keyvault client is a good idea for encryption because it enforces + a crucial security practice: the separation of data and keys. """ - key_vault_client = self.config.encryption_config.get("key_vault_client") - if not all([kms_providers, key_vault_namespace, key_vault_client]): raise ValueError("Incomplete encryption_config provided.") @@ -132,8 +132,6 @@ def __init__(self, config): # --- EXISTING CODE --- self.persona_collection = self.db[MemoryType.PERSONAS.value] self.toolbox_collection = self.db[MemoryType.TOOLBOX.value] - print("TOOLBOX"+MemoryType.TOOLBOX.value) - print("&^^^^^^^^") self.short_term_memory_collection = self.db[MemoryType.SHORT_TERM_MEMORY.value] self.long_term_memory_collection = self.db[MemoryType.LONG_TERM_MEMORY.value] self.conversation_memory_collection = self.db[MemoryType.CONVERSATION_MEMORY.value] @@ -161,6 +159,7 @@ def __init__(self, config): # Set lazy mode if immediate creation fails self.config.lazy_vector_indexes = True + def _setup_embedding_provider(self, config: MongoDBConfig): """ Setup the embedding provider based on configuration. @@ -362,7 +361,7 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: # Clean custom ID fields custom_id_fields = [ - "persona_id", "tool_id", "workflow_id", "short_term_memory_id", + "persona_id", "tool_id", "workflow_id", "short_term_memory_id", "agent_id" ] if memory_store_type != MemoryType.CONVERSATION_MEMORY: @@ -387,7 +386,7 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: # Explicitly encrypt the field's value encrypted_value = self._client_encryption.encrypt( value, - algorithm, + Algorithm(algorithm), key_alt_name=key_alt_name ) data_copy[field_name] = encrypted_value @@ -1004,7 +1003,6 @@ def list_all(self, memory_store_type: MemoryType, include_embedding: bool = Fals # If CSFLE is not configured, just return the raw documents as they are return documents - def update_by_id(self, id: str, data: Dict[str, Any], memory_store_type: MemoryType) -> bool: """ Update a document in a memory store type in MongoDB by _id. From c517e109c2bf54be87166b8be63b74bc324e7277 Mon Sep 17 00:00:00 2001 From: ranfysvalle02 Date: Wed, 27 Aug 2025 16:31:35 -0400 Subject: [PATCH 5/6] update comments a bit --- .../memory_provider/mongodb/provider.py | 82 +++++++++++++------ 1 file changed, 57 insertions(+), 25 deletions(-) diff --git a/src/memorizz/memory_provider/mongodb/provider.py b/src/memorizz/memory_provider/mongodb/provider.py index 3f6eac0..4003957 100644 --- a/src/memorizz/memory_provider/mongodb/provider.py +++ b/src/memorizz/memory_provider/mongodb/provider.py @@ -40,11 +40,15 @@ def __init__(self, uri: str, db_name: str = "memorizz", lazy_vector_indexes: boo If False, vector indexes are created immediately during initialization (requires embedding configuration). Default: False (maintains backward compatibility) embedding_provider : str or EmbeddingManager, optional - Embedding provider to use. + Embedding provider to use. Can be: + - EmbeddingManager instance (explicit injection) + - String provider name ("openai", "ollama", "voyageai") + - None (uses global embedding configuration) embedding_config : Dict[str, Any], optional - Configuration for the embedding provider. + Configuration for the embedding provider. Only used when embedding_provider is a string. + Example: {"model": "text-embedding-3-small", "dimensions": 512} encryption_config : Dict[str, Any], optional - Configuration for encryption settings. + Configuration for Client-Side Field Level Encryption (CSFLE) settings. """ self.uri = uri self.db_name = db_name @@ -58,7 +62,7 @@ def __init__(self, uri: str, db_name: str = "memorizz", lazy_vector_indexes: boo class MongoDBProvider(MemoryProvider): """MongoDB implementation of the MemoryProvider interface.""" - def __init__(self, config): + def __init__(self, config: MongoDBConfig): """ Initialize the MongoDB provider with configuration settings. @@ -71,7 +75,7 @@ def __init__(self, config): self.client = MongoClient(config.uri) self.db = self.client[config.db_name] - # --- NEW: Initialize CSFLE with smart checks and dynamic keymap creation --- + # --- Initialize CSFLE with smart checks and dynamic keymap creation --- self._client_encryption = None self.keymap = {} # Initialize the keymap as an empty dictionary @@ -311,11 +315,23 @@ def _create_vector_indexes_for_memory_stores(self) -> None: index_name="vector_index", memory_store=memory_store_present, ) - def _is_encrypted_field(self, collection_name: str, field_name: str) -> bool: + + def _is_encrypted_field(self, collection_name: str, field_name: str) -> Optional[str]: """ Checks if a field should be encrypted based on the encryption config. + + Parameters: + ----------- + collection_name : str + The name of the collection being checked. + field_name : str + The name of the field to check for encryption. + + Returns: + -------- + Optional[str] + The encryption algorithm as a string if the field should be encrypted, otherwise None. """ - if not self.config.encryption_config: return None @@ -332,7 +348,19 @@ def _is_encrypted_field(self, collection_name: str, field_name: str) -> bool: def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: """ - Store data in MongoDB, encrypting sensitive fields. + Store data in MongoDB, encrypting sensitive fields if CSFLE is configured. + + Parameters: + ----------- + data : Dict[str, Any] + The document to be stored. + memory_store_type : MemoryType + The type of memory store (e.g., "persona", "toolbox", etc.) + + Returns: + -------- + str + The ID of the inserted/updated document (MongoDB _id). """ collection = None collection_name = memory_store_type.value @@ -372,7 +400,7 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: for field in custom_id_fields: data_copy.pop(field, None) - # Encryption Logic: check if CSFLE is configured + # Encryption Logic: check if CSFLE is configured and a keymap exists if self._client_encryption and self.keymap: for field_name, value in list(data_copy.items()): algorithm = self._is_encrypted_field(collection_name, field_name) @@ -391,7 +419,7 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: ) data_copy[field_name] = encrypted_value - # Update or insert logic + # If document has MongoDB _id, update it (upsert) if "_id" in data_copy: result = collection.update_one( {"_id": data_copy["_id"]}, @@ -400,13 +428,14 @@ def store(self, data: Dict[str, Any], memory_store_type: MemoryType) -> str: ) return str(data_copy["_id"]) else: + # For new documents, let MongoDB generate _id automatically result = collection.insert_one(data_copy) return str(result.inserted_id) - def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType, limit: int = 1, include_embedding: bool = False) -> Optional[Dict[str, Any]]: + def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType, limit: int = 1, include_embedding: bool = False) -> Optional[List[Dict[str, Any]]]: """ - Retrieve a document from MongoDB. + Retrieve documents from MongoDB, decrypting fields if CSFLE is configured. Parameters: ----------- @@ -419,8 +448,8 @@ def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType Returns: -------- - Optional[Dict[str, Any]] - The retrieved document, or None if not found. + Optional[List[Dict[str, Any]]] + A list of retrieved and decrypted documents, or None if not found. """ # Define projection to exclude embeddings by default @@ -449,26 +478,28 @@ def retrieve_by_query(self, query: Dict[str, Any], memory_store_type: MemoryType return None # --- New Decryption Logic --- + # Only attempt decryption if CSFLE is enabled if self._client_encryption: decrypted_documents = [] for doc in documents: decrypted_doc = doc.copy() for field_name, value in doc.items(): - # Check if the value is a Binary object (which indicates encryption) + # Encrypted fields are stored as BSON Binary subtype 6 if isinstance(value, Binary): try: + # Attempt to decrypt the value decrypted_value = self._client_encryption.decrypt(value) decrypted_doc[field_name] = decrypted_value except PyMongoError as e: logger.error(f"Failed to decrypt field '{field_name}' in document {doc.get('_id', '')}: {e}") - # Keep the encrypted value if decryption fails - decrypted_doc[field_name] = value + # Keep the encrypted value if decryption fails to avoid crashing + decrypted_doc[field_name] = value decrypted_documents.append(decrypted_doc) return decrypted_documents # If CSFLE is not configured, just return the raw documents return documents - + def retrieve_by_id(self, id: str, memory_store_type: MemoryType) -> Optional[Dict[str, Any]]: """ Retrieve a document from MongoDB by _id. @@ -940,7 +971,7 @@ def delete_all(self, memory_store_type: MemoryType) -> bool: def list_all(self, memory_store_type: MemoryType, include_embedding: bool = False) -> List[Dict[str, Any]]: """ - List all documents within a memory store type in MongoDB, decrypting encrypted fields. + List all documents within a memory store, decrypting fields if CSFLE is configured. Parameters: ----------- @@ -981,14 +1012,14 @@ def list_all(self, memory_store_type: MemoryType, include_embedding: bool = Fals # Retrieve all documents from the collection documents = list(collection.find({}, projection)) - # --- NEW: Decryption Logic --- + # --- Decryption Logic --- # Only attempt to decrypt if CSFLE is configured and initialized if self._client_encryption: decrypted_documents = [] for doc in documents: decrypted_doc = doc.copy() for field_name, value in doc.items(): - # Check if the value is a Binary object (which indicates encryption) + # Check if the value is a BSON Binary object (which indicates encryption) if isinstance(value, Binary): try: # Attempt to decrypt the field @@ -997,12 +1028,13 @@ def list_all(self, memory_store_type: MemoryType, include_embedding: bool = Fals except PyMongoError as e: logger.error(f"Failed to decrypt field '{field_name}' in document {doc.get('_id', '')}: {e}") # Keep the encrypted value if decryption fails - decrypted_doc[field_name] = value + decrypted_doc[field_name] = value decrypted_documents.append(decrypted_doc) return decrypted_documents - # If CSFLE is not configured, just return the raw documents as they are + # If CSFLE is not configured, just return the raw documents return documents + def update_by_id(self, id: str, data: Dict[str, Any], memory_store_type: MemoryType) -> bool: """ Update a document in a memory store type in MongoDB by _id. @@ -1565,7 +1597,7 @@ def delete_memagent(self, agent_id: str, cascade: bool = False) -> bool: True if deletion was successful, False otherwise. """ if cascade: - # Retrieve the memagent + # Retrieve the memagent memagent = self.retrieve_memagent(agent_id) if memagent is None: @@ -1619,7 +1651,7 @@ def _delete_memory_units_by_memory_id(self, memory_id: str, memory_type: MemoryT self.toolbox_collection.delete_many({"memory_id": memory_id}) elif memory_type == MemoryType.MEMAGENT: self.memagent_collection.delete_many({"memory_id": memory_id}) - + def _setup_vector_search_index(self, collection, index_name="vector_index", memory_store: bool = False): """ From b28a3a254aaae03bc183dcb5d4621f6151027563 Mon Sep 17 00:00:00 2001 From: ranfysvalle02 Date: Wed, 27 Aug 2025 16:35:25 -0400 Subject: [PATCH 6/6] remove extra line break --- src/memorizz/memory_provider/mongodb/provider.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/memorizz/memory_provider/mongodb/provider.py b/src/memorizz/memory_provider/mongodb/provider.py index 4003957..d2b50fa 100644 --- a/src/memorizz/memory_provider/mongodb/provider.py +++ b/src/memorizz/memory_provider/mongodb/provider.py @@ -1651,7 +1651,6 @@ def _delete_memory_units_by_memory_id(self, memory_id: str, memory_type: MemoryT self.toolbox_collection.delete_many({"memory_id": memory_id}) elif memory_type == MemoryType.MEMAGENT: self.memagent_collection.delete_many({"memory_id": memory_id}) - def _setup_vector_search_index(self, collection, index_name="vector_index", memory_store: bool = False): """