diff --git a/pgcrypto/__init__.py b/pgcrypto/__init__.py index 9981891..816bb61 100644 --- a/pgcrypto/__init__.py +++ b/pgcrypto/__init__.py @@ -1,11 +1,11 @@ DIGEST_SQL = "digest(%s, 'sha512')" HMAC_SQL = "hmac(%s, '{}', 'sha512')" -PGP_PUB_ENCRYPT_SQL_WITH_NULLIF = "pgp_pub_encrypt(nullif(%s, NULL)::text, dearmor('{}'))" +PGP_PUB_ENCRYPT_SQL_WITH_NULLIF = "pgp_pub_encrypt(nullif(%s, NULL)::text, {})" PGP_SYM_ENCRYPT_SQL_WITH_NULLIF = "pgp_sym_encrypt(nullif(%s, NULL)::text, '{}')" -PGP_PUB_ENCRYPT_SQL = "pgp_pub_encrypt(%s, dearmor('{}'))" +PGP_PUB_ENCRYPT_SQL = "pgp_pub_encrypt(%s, {})" PGP_SYM_ENCRYPT_SQL = "pgp_sym_encrypt(%s, '{}')" -PGP_PUB_DECRYPT_SQL = "pgp_pub_decrypt(%s, dearmor('{}'))::%s" +PGP_PUB_DECRYPT_SQL = "pgp_pub_decrypt(%s, {})::%s" PGP_SYM_DECRYPT_SQL = "pgp_sym_decrypt(%s, '{}')::%s" diff --git a/pgcrypto/mixins.py b/pgcrypto/mixins.py index 8d4e897..dd13ab3 100644 --- a/pgcrypto/mixins.py +++ b/pgcrypto/mixins.py @@ -10,14 +10,36 @@ ) -def get_setting(connection, key): +def get_setting(connection, key, **kwargs): """Get key from connection or default to settings.""" if key in connection.settings_dict: return connection.settings_dict[key] else: + if 'default' in kwargs: + return getattr(settings, key, kwargs['default']) return getattr(settings, key) +def get_pgp_public_key_sql(connection): + dearmored_file = get_setting(connection, 'PUBLIC_PGP_KEY_DEARMORED_FILE', default=None) + if dearmored_file is not None: + return "pg_read_binary_file('{}')".format(dearmored_file) + armored_file = get_setting(connection, 'PUBLIC_PGP_KEY_ARMORED_FILE', default=None) + if armored_file is not None: + return "dearmor(pg_read_file('{}'))".format(armored_file) + return "dearmor('{}')".format(get_setting(connection, 'PUBLIC_PGP_KEY')) + + +def get_pgp_private_key_sql(connection): + dearmored_file = get_setting(connection, 'PRIVATE_PGP_KEY_DEARMORED_FILE', default=None) + if dearmored_file is not None: + return "pg_read_binary_file('{}')".format(dearmored_file) + armored_file = get_setting(connection, 'PRIVATE_PGP_KEY_ARMORED_FILE', default=None) + if armored_file is not None: + return "dearmor(pg_read_file('{}'))".format(armored_file) + return "dearmor('{}')".format(get_setting(connection, 'PRIVATE_PGP_KEY')) + + class DecryptedCol(Col): """Provide DecryptedCol support without using `extra` sql.""" @@ -133,11 +155,11 @@ class PGPPublicKeyFieldMixin(PGPMixin): def get_placeholder(self, value=None, compiler=None, connection=None): """Tell postgres to encrypt this field using PGP.""" - return self.encrypt_sql.format(get_setting(connection, 'PUBLIC_PGP_KEY')) + return self.encrypt_sql.format(get_pgp_public_key_sql(connection)) def get_decrypt_sql(self, connection): """Get decrypt sql.""" - return self.decrypt_sql.format(get_setting(connection, 'PRIVATE_PGP_KEY')) + return self.decrypt_sql.format(get_pgp_private_key_sql(connection)) class PGPSymmetricKeyFieldMixin(PGPMixin):