diff --git a/pytpcc/MONGODB_EXAMPLE b/pytpcc/MONGODB_EXAMPLE index 4b62bc2..92ec5ab 100644 --- a/pytpcc/MONGODB_EXAMPLE +++ b/pytpcc/MONGODB_EXAMPLE @@ -11,3 +11,4 @@ denormalize = True retry_writes = True # user = username # passwd = passwd +fle = False diff --git a/pytpcc/drivers/mongodbdriver.py b/pytpcc/drivers/mongodbdriver.py index cd4dd32..8b5bd2a 100644 --- a/pytpcc/drivers/mongodbdriver.py +++ b/pytpcc/drivers/mongodbdriver.py @@ -38,6 +38,11 @@ from pprint import pformat from time import sleep import pymongo +from bson import binary +import base64 +from pymongo.encryption import (Algorithm, + ClientEncryption) +from pymongo.encryption_options import AutoEncryptionOpts import constants from abstractdriver import AbstractDriver @@ -82,12 +87,12 @@ "C_FIRST", # VARCHAR "C_MIDDLE", # VARCHAR "C_LAST", # VARCHAR - "C_STREET_1", # VARCHAR - "C_STREET_2", # VARCHAR - "C_CITY", # VARCHAR - "C_STATE", # VARCHAR - "C_ZIP", # VARCHAR - "C_PHONE", # VARCHAR + "C_STREET_1", # VARCHAR CONFIDENTIAL + "C_STREET_2", # VARCHAR CONFIDENTIAL + "C_CITY", # VARCHAR CONFIDENTIAL + "C_STATE", # VARCHAR CONFIDENTIAL + "C_ZIP", # VARCHAR CONFIDENTIAL + "C_PHONE", # VARCHAR CONFIDENTIAL "C_SINCE", # TIMESTAMP "C_CREDIT", # VARCHAR "C_CREDIT_LIM", # FLOAT @@ -96,7 +101,7 @@ "C_YTD_PAYMENT", # FLOAT "C_PAYMENT_CNT", # INTEGER "C_DELIVERY_CNT", # INTEGER - "C_DATA", # VARCHAR + "C_DATA", # VARCHAR CONFIDENTIAL ], constants.TABLENAME_STOCK: [ "S_I_ID", # INTEGER @@ -200,7 +205,8 @@ class MongodbDriver(AbstractDriver): "secondary_reads": ("If true, we will allow secondary reads", True), "retry_writes": ("If true, we will enable retryable writes", True), "causal_consistency": ("If true, we will perform causal reads ", True), - "shards": ("If >1 then sharded", "1") + "shards": ("If >1 then sharded", "1"), + "fle": ("If true, confidential data will be encrypted using CSFLE", False) } DENORMALIZED_TABLES = [ constants.TABLENAME_ORDERS, @@ -215,7 +221,6 @@ def __init__(self, ddl): 
self.read_preference = "primary" self.database = None self.client = None - self.executed = False self.w_orders = {} # things that are not better can't be set in config self.batch_writes = True @@ -232,6 +237,9 @@ def __init__(self, ddl): self.result_doc = {} self.warehouses = 0 self.shards = 1 + self.use_encryption = 0 + self.load = False + self.execute = False ## Create member mapping to collections for name in constants.ALL_TABLES: @@ -268,6 +276,9 @@ def loadConfig(self, config): self.secondary_reads = config['secondary_reads'] == 'True' if self.secondary_reads: self.read_preference = "nearest" + self.use_encryption = config['fle'] == 'True' + self.load = config['load'] == 'True' + self.execute = config['execute'] == 'True' if 'write_concern' in config and config['write_concern'] and config['write_concern'] != '1': # only expecting string 'majority' as an alternative to w:1 @@ -297,10 +308,25 @@ def loadConfig(self, config): real_uri = uri[0:pindex]+userpassword+uri[pindex:] display_uri = uri[0:pindex]+usersecret+uri[pindex:] + # Encryption - FLE + auto_encryption_opts = None + if self.use_encryption: + local_master_key = binary.Binary(base64.b64decode( + 'YB82/JCPNOcNr1NRMVojIWVTHv1EF7uI5VcNs+jTg9NAzBMLQ1b3kQ3BhsnLza9DZT2tOuj6jeVYO890s18LwkfJRaPFrx5FhcZmmhM5wYkNw/IO0PVF7Z9+pNPB3EOw')) + kms_providers = {"local": {"key": local_master_key}} + + key_vault_namespace = "encryption.dataKeys" + key_vault_db_name, key_vault_coll_name = key_vault_namespace.split(".", 1) + + auto_encryption_opts = AutoEncryptionOpts( + kms_providers, key_vault_namespace, bypass_auto_encryption=True) + ## IF + self.client = pymongo.MongoClient(real_uri, retryWrites=self.retry_writes, readPreference=self.read_preference, - readConcernLevel=self.read_concern) + readConcernLevel=self.read_concern, + auto_encryption_opts=auto_encryption_opts) self.result_doc['before']=self.get_server_status() @@ -333,6 +359,40 @@ def loadConfig(self, config): uniq = False ## IF ## FOR + + if 
self.use_encryption: + if not config['load'] and not config['execute'] and config["reset"]: + key_vault = self.client[key_vault_db_name][key_vault_coll_name] + key_vault.drop() + key_vault.create_index( + "keyAltNames", + unique=True, + partialFilterExpression={"keyAltNames": {"$exists": True}}) + ## IF + + self.client_encryption = ClientEncryption( + kms_providers, + key_vault_namespace, + self.client, + self.database.codec_options) + + if not config['load'] and not config['execute'] and config["reset"]: + self.client_encryption.create_data_key('local', + key_alt_names=['C_STREET_1_fle_data_key']) + self.client_encryption.create_data_key('local', + key_alt_names=['C_STREET_2_fle_data_key']) + self.client_encryption.create_data_key('local', + key_alt_names=['C_CITY_fle_data_key']) + self.client_encryption.create_data_key('local', + key_alt_names=['C_STATE_fle_data_key']) + self.client_encryption.create_data_key('local', + key_alt_names=['C_ZIP_fle_data_key']) + self.client_encryption.create_data_key('local', + key_alt_names=['C_PHONE_fle_data_key']) + self.client_encryption.create_data_key('local', + key_alt_names=['C_DATA_fle_data_key']) + ## IF + ## IF except pymongo.errors.OperationFailure as exc: logging.error("OperationFailure %d (%s) when connected to %s: ", exc.code, exc.details, display_uri) @@ -345,6 +405,9 @@ def loadConfig(self, config): logging.error("ConnectionFailure %d (%s) when connected to %s: ", exc.code, exc.details, display_uri) return + except pymongo.errors.EncryptionError as err: + logging.error("EncryptionError (%s) when connected to %s: ", str(err), display_uri) + return except pymongo.errors.PyMongoError, err: logging.error("Some general error (%s) when connected to %s: ", str(err), display_uri) print "Got some other error: %s" % str(err) @@ -404,6 +467,39 @@ def loadTuples(self, tableName, tuples): t2.append(w) tuples3.append(t2) tuples = tuples3 + + # If this is a CUSTOMER record, then we need to encrypt confidential fields (if 
needed) + elif tableName == constants.TABLENAME_CUSTOMER and self.use_encryption: + for t in tuples: + t[6] = self.client_encryption.encrypt( + t[6], + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + key_alt_name='C_STREET_1_fle_data_key') + t[7] = self.client_encryption.encrypt( + t[7], + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + key_alt_name='C_STREET_2_fle_data_key') + t[8] = self.client_encryption.encrypt( + t[8], + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + key_alt_name='C_CITY_fle_data_key') + t[9] = self.client_encryption.encrypt( + t[9], + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + key_alt_name='C_STATE_fle_data_key') + t[10] = self.client_encryption.encrypt( + t[10], + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + key_alt_name='C_ZIP_fle_data_key') + t[11] = self.client_encryption.encrypt( + t[11], + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + key_alt_name='C_PHONE_fle_data_key') + t[20] = self.client_encryption.encrypt( + t[20], + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + key_alt_name='C_DATA_fle_data_key') + for t in tuples: tuple_dicts.append(dict([(columns[i], t[i]) for i in num_columns])) ## FOR @@ -420,13 +516,23 @@ def loadFinishDistrict(self, w_id, d_id): self.w_orders.clear() ## IF + def loadFinish(self): + """Optional callback to indicate to the driver that the data loading phase is finished.""" + if self.load: + if self.use_encryption: + self.client_encryption.close() + self.client.close() + def executeStart(self): """Optional callback before the execution for each client starts""" return None def executeFinish(self): """Callback after the execution for each client finishes""" - return None + if self.execute: + if self.use_encryption: + self.client_encryption.close() + self.client.close() ## ---------------------------------------------- ## doDelivery @@ -812,7 +918,7 @@ def _doOrderStatusTxn(self, s, params): search_fields["C_ID"] = c_id c = self.customer.find_one(search_fields, return_fields, 
session=s) assert c, "Couldn't find customer in order status" - else: + elif c_last != None: # getCustomersByLastName # Get the midpoint customer's id search_fields['C_LAST'] = c_last @@ -935,7 +1041,7 @@ def _doPaymentTxn(self, s, params): search_fields["C_ID"] = c_id c = self.customer.find_one(search_fields, return_fields, session=s) assert c, "No customer in payment w_id %d d_id %d c_id %d" % (w_id, d_id, c_id) - else: + elif c_last != None: # getCustomersByLastName # Get the midpoint customer's id search_fields['C_LAST'] = c_last @@ -962,6 +1068,11 @@ def _doPaymentTxn(self, s, params): c_data = (new_data + "|" + c_data) if len(c_data) > constants.MAX_C_DATA: c_data = c_data[:constants.MAX_C_DATA] + if self.use_encryption: + c_data = self.client_encryption.encrypt( + c_data, + Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random, + key_alt_name='C_DATA_fle_data_key') customer_update["$set"] = {"C_DATA": c_data} ## IF