Skip to content

Add Client-Side Field Level Encryption (CSFLE) feature to mongodbdriver #15

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytpcc/MONGODB_EXAMPLE
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ denormalize = True
retry_writes = True
# user = username
# passwd = passwd
fle = False
137 changes: 124 additions & 13 deletions pytpcc/drivers/mongodbdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@
from pprint import pformat
from time import sleep
import pymongo
from bson import binary
import base64
from pymongo.encryption import (Algorithm,
ClientEncryption)
from pymongo.encryption_options import AutoEncryptionOpts

import constants
from abstractdriver import AbstractDriver
Expand Down Expand Up @@ -82,12 +87,12 @@
"C_FIRST", # VARCHAR
"C_MIDDLE", # VARCHAR
"C_LAST", # VARCHAR
"C_STREET_1", # VARCHAR
"C_STREET_2", # VARCHAR
"C_CITY", # VARCHAR
"C_STATE", # VARCHAR
"C_ZIP", # VARCHAR
"C_PHONE", # VARCHAR
"C_STREET_1", # VARCHAR CONFIDENTIAL
"C_STREET_2", # VARCHAR CONFIDENTIAL
"C_CITY", # VARCHAR CONFIDENTIAL
"C_STATE", # VARCHAR CONFIDENTIAL
"C_ZIP", # VARCHAR CONFIDENTIAL
"C_PHONE", # VARCHAR CONFIDENTIAL
"C_SINCE", # TIMESTAMP
"C_CREDIT", # VARCHAR
"C_CREDIT_LIM", # FLOAT
Expand All @@ -96,7 +101,7 @@
"C_YTD_PAYMENT", # FLOAT
"C_PAYMENT_CNT", # INTEGER
"C_DELIVERY_CNT", # INTEGER
"C_DATA", # VARCHAR
"C_DATA", # VARCHAR CONFIDENTIAL
],
constants.TABLENAME_STOCK: [
"S_I_ID", # INTEGER
Expand Down Expand Up @@ -200,7 +205,8 @@ class MongodbDriver(AbstractDriver):
"secondary_reads": ("If true, we will allow secondary reads", True),
"retry_writes": ("If true, we will enable retryable writes", True),
"causal_consistency": ("If true, we will perform causal reads ", True),
"shards": ("If >1 then sharded", "1")
"shards": ("If >1 then sharded", "1"),
"fle": ("If true, confidential data will be encrypted using CSFLE", False)
}
DENORMALIZED_TABLES = [
constants.TABLENAME_ORDERS,
Expand All @@ -215,7 +221,6 @@ def __init__(self, ddl):
self.read_preference = "primary"
self.database = None
self.client = None
self.executed = False
self.w_orders = {}
# things that are not better can't be set in config
self.batch_writes = True
Expand All @@ -232,6 +237,9 @@ def __init__(self, ddl):
self.result_doc = {}
self.warehouses = 0
self.shards = 1
self.use_encryption = 0
self.load = False
self.execute = False

## Create member mapping to collections
for name in constants.ALL_TABLES:
Expand Down Expand Up @@ -268,6 +276,9 @@ def loadConfig(self, config):
self.secondary_reads = config['secondary_reads'] == 'True'
if self.secondary_reads:
self.read_preference = "nearest"
self.use_encryption = config['fle'] == 'True'
self.load = config['load'] == 'True'
self.execute = config['execute'] == 'True'

if 'write_concern' in config and config['write_concern'] and config['write_concern'] != '1':
# only expecting string 'majority' as an alternative to w:1
Expand Down Expand Up @@ -297,10 +308,25 @@ def loadConfig(self, config):
real_uri = uri[0:pindex]+userpassword+uri[pindex:]
display_uri = uri[0:pindex]+usersecret+uri[pindex:]

# Encryption - FLE
auto_encryption_opts = None
if self.use_encryption:
local_master_key = binary.Binary(base64.b64decode(
'YB82/JCPNOcNr1NRMVojIWVTHv1EF7uI5VcNs+jTg9NAzBMLQ1b3kQ3BhsnLza9DZT2tOuj6jeVYO890s18LwkfJRaPFrx5FhcZmmhM5wYkNw/IO0PVF7Z9+pNPB3EOw'))
kms_providers = {"local": {"key": local_master_key}}

key_vault_namespace = "encryption.dataKeys"
key_vault_db_name, key_vault_coll_name = key_vault_namespace.split(".", 1)

auto_encryption_opts = AutoEncryptionOpts(
kms_providers, key_vault_namespace, bypass_auto_encryption=True)
## IF

self.client = pymongo.MongoClient(real_uri,
retryWrites=self.retry_writes,
readPreference=self.read_preference,
readConcernLevel=self.read_concern)
readConcernLevel=self.read_concern,
auto_encryption_opts=auto_encryption_opts)

self.result_doc['before']=self.get_server_status()

Expand Down Expand Up @@ -333,6 +359,40 @@ def loadConfig(self, config):
uniq = False
## IF
## FOR

if self.use_encryption:
if not config['load'] and not config['execute'] and config["reset"]:
key_vault = self.client[key_vault_db_name][key_vault_coll_name]
key_vault.drop()
key_vault.create_index(
"keyAltNames",
unique=True,
partialFilterExpression={"keyAltNames": {"$exists": True}})
## IF

self.client_encryption = ClientEncryption(
kms_providers,
key_vault_namespace,
self.client,
self.database.codec_options)

if not config['load'] and not config['execute'] and config["reset"]:
self.client_encryption.create_data_key('local',
key_alt_names=['C_STREET_1_fle_data_key'])
self.client_encryption.create_data_key('local',
key_alt_names=['C_STREET_2_fle_data_key'])
self.client_encryption.create_data_key('local',
key_alt_names=['C_CITY_fle_data_key'])
self.client_encryption.create_data_key('local',
key_alt_names=['C_STATE_fle_data_key'])
self.client_encryption.create_data_key('local',
key_alt_names=['C_ZIP_fle_data_key'])
self.client_encryption.create_data_key('local',
key_alt_names=['C_PHONE_fle_data_key'])
self.client_encryption.create_data_key('local',
key_alt_names=['C_DATA_fle_data_key'])
## IF
## IF
except pymongo.errors.OperationFailure as exc:
logging.error("OperationFailure %d (%s) when connected to %s: ",
exc.code, exc.details, display_uri)
Expand All @@ -345,6 +405,9 @@ def loadConfig(self, config):
logging.error("ConnectionFailure %d (%s) when connected to %s: ",
exc.code, exc.details, display_uri)
return
except pymongo.errors.EncryptionError as err:
logging.error("EncryptionError (%s) when connected to %s: ", str(err), display_uri)
return
except pymongo.errors.PyMongoError, err:
logging.error("Some general error (%s) when connected to %s: ", str(err), display_uri)
print "Got some other error: %s" % str(err)
Expand Down Expand Up @@ -404,6 +467,39 @@ def loadTuples(self, tableName, tuples):
t2.append(w)
tuples3.append(t2)
tuples = tuples3

# If this is a CUSTOMER record, then we need to encrypt its confidential fields (when FLE is enabled)
elif tableName == constants.TABLENAME_CUSTOMER and self.use_encryption:
for t in tuples:
t[6] = self.client_encryption.encrypt(
t[6],
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random,
key_alt_name='C_STREET_1_fle_data_key')
t[7] = self.client_encryption.encrypt(
t[7],
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random,
key_alt_name='C_STREET_2_fle_data_key')
t[8] = self.client_encryption.encrypt(
t[8],
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random,
key_alt_name='C_CITY_fle_data_key')
t[9] = self.client_encryption.encrypt(
t[9],
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random,
key_alt_name='C_STATE_fle_data_key')
t[10] = self.client_encryption.encrypt(
t[10],
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random,
key_alt_name='C_ZIP_fle_data_key')
t[11] = self.client_encryption.encrypt(
t[11],
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random,
key_alt_name='C_PHONE_fle_data_key')
t[20] = self.client_encryption.encrypt(
t[20],
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random,
key_alt_name='C_DATA_fle_data_key')

for t in tuples:
tuple_dicts.append(dict([(columns[i], t[i]) for i in num_columns]))
## FOR
Expand All @@ -420,13 +516,23 @@ def loadFinishDistrict(self, w_id, d_id):
self.w_orders.clear()
## IF

def loadFinish(self):
    """Optional callback to indicate to the driver that the data loading phase is finished.

    When this driver instance took part in the load phase (``self.load``),
    close the handles it holds: the ClientEncryption helper first (only when
    CSFLE is enabled), then the MongoClient. Returns None.
    """
    if not self.load:
        return
    # In the visible code client_encryption is only assigned part-way through
    # loadConfig, so look it up defensively instead of assuming it exists.
    client_encryption = getattr(self, "client_encryption", None)
    if self.use_encryption and client_encryption is not None:
        client_encryption.close()
    # self.client is initialised to None and only replaced once loadConfig
    # has built the MongoClient.
    if self.client is not None:
        self.client.close()

def executeStart(self):
    """Optional hook called before each client's execution phase begins.

    This driver has nothing to prepare, so the hook simply returns None.
    """
    return None

def executeFinish(self):
    """Callback after the execution for each client finishes.

    When this driver instance took part in the execute phase
    (``self.execute``), close the handles it holds: the ClientEncryption
    helper first (only when CSFLE is enabled), then the MongoClient.
    Returns None.
    """
    # NOTE: no unconditional early return here, otherwise the cleanup below
    # would never run.
    if not self.execute:
        return
    # In the visible code client_encryption is only assigned part-way through
    # loadConfig, so look it up defensively instead of assuming it exists.
    client_encryption = getattr(self, "client_encryption", None)
    if self.use_encryption and client_encryption is not None:
        client_encryption.close()
    # self.client is initialised to None and only replaced once loadConfig
    # has built the MongoClient.
    if self.client is not None:
        self.client.close()

## ----------------------------------------------
## doDelivery
Expand Down Expand Up @@ -812,7 +918,7 @@ def _doOrderStatusTxn(self, s, params):
search_fields["C_ID"] = c_id
c = self.customer.find_one(search_fields, return_fields, session=s)
assert c, "Couldn't find customer in order status"
else:
elif c_last != None:
# getCustomersByLastName
# Get the midpoint customer's id
search_fields['C_LAST'] = c_last
Expand Down Expand Up @@ -935,7 +1041,7 @@ def _doPaymentTxn(self, s, params):
search_fields["C_ID"] = c_id
c = self.customer.find_one(search_fields, return_fields, session=s)
assert c, "No customer in payment w_id %d d_id %d c_id %d" % (w_id, d_id, c_id)
else:
elif c_last != None:
# getCustomersByLastName
# Get the midpoint customer's id
search_fields['C_LAST'] = c_last
Expand All @@ -962,6 +1068,11 @@ def _doPaymentTxn(self, s, params):
c_data = (new_data + "|" + c_data)
if len(c_data) > constants.MAX_C_DATA:
c_data = c_data[:constants.MAX_C_DATA]
if self.use_encryption:
c_data = self.client_encryption.encrypt(
c_data,
Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random,
key_alt_name='C_DATA_fle_data_key')
customer_update["$set"] = {"C_DATA": c_data}
## IF

Expand Down