diff --git a/crypto_utils.py b/crypto_utils.py new file mode 100644 index 0000000000..01eb0fa5b4 --- /dev/null +++ b/crypto_utils.py @@ -0,0 +1,95 @@ +import base64 +from pathlib import Path +from typing import BinaryIO, Union + +from cryptography.fernet import Fernet +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from cryptography.hazmat.primitives.hashes import SHA256 +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from cryptography.hazmat.primitives.padding import PKCS7 +from cryptography.hazmat.primitives import hmac + +class CryptoUtils: + @staticmethod + def decrypt_fernet(data: bytes, key: str) -> bytes: + """ + Decrypt data using Fernet symmetric encryption. + + :param data: Encrypted data. + :param key: Base64-encoded 32-byte key. + :return: Decrypted data. + """ + f = Fernet(base64.b64decode(key)) + return f.decrypt(data) + + @staticmethod + def decrypt_aes_gcm(data: bytes, nonce: bytes, key: bytes) -> bytes: + """ + Decrypt data using AES GCM symmetric encryption. + + :param data: Encrypted data. + :param nonce: Nonce used for encryption. + :param key: 32-byte encryption key. + :return: Decrypted data. + """ + aes_gcm = AESGCM(key) + return aes_gcm.decrypt(nonce, data, None) + + @staticmethod + def decrypt_hkdf_hmac_sha256(data: bytes, salt: bytes, key: bytes) -> bytes: + """ + Decrypt data using HKDF-HMAC-SHA256 key derivation and symmetric encryption. + + :param data: Encrypted data. + :param salt: Salt used for key derivation. + :param key: 32-byte encryption key. + :return: Decrypted data. + """ + kdf = HKDF( + algorithm=SHA256(), + length=32, + salt=salt, + info=None, + backend=None, + ) + derived_key = kdf.derive(key) + + decryptor = hmac.HMAC(derived_key, SHA256(), backend=None) + decryptor.update(data[:16]) + tag = decryptor.finalize() + + pkcs7_padding = PKCS7(128).unpad(data[16:]) + + if hmac.new(derived_key, pkcs7_padding, SHA256()).digest() != tag: + raise ValueError("Invalid tag.") + + return pkcs7_padding + + @staticmethod + def decrypt_file(file_path: Union[str, Path], password: str) -> bytes: + """ + Decrypt a file using the provided password. + + :param file_path: Path to the encrypted file. + :param password: Password used for decryption. + :return: Decrypted file content. + """ + with open(file_path, "rb") as file: + file_content = file.read() + + encryption_method = file_content[:4].decode("utf-8") + + if encryption_method == "fern": + return CryptoUtils.decrypt_fernet(file_content[4:], password.encode()) + + if encryption_method == "aesg": + nonce = file_content[4:16] + encrypted_data = file_content[16:] + return CryptoUtils.decrypt_aes_gcm(encrypted_data, nonce, password.encode()) + + if encryption_method == "hkdf": + salt = file_content[4:16] + encrypted_data = file_content[16:] + return CryptoUtils.decrypt_hkdf_hmac_sha256(encrypted_data, salt, password.encode()) + + raise ValueError(f"Unsupported encryption method: {encryption_method}") \ No newline at end of file diff --git a/tests/test_crypto_utils.py b/tests/test_crypto_utils.py new file mode 100644 index 0000000000..cc11521bc9 --- /dev/null +++ b/tests/test_crypto_utils.py @@ -0,0 +1,35 @@ + import os +from cryptography.fernet import Fernet + + +def generate_key(): + """ + Generates a new encryption key. + + Returns: + bytes: The encryption key. + """ + return Fernet.generate_key() + + +def encrypt_file(file_path: str, key: bytes): + """ + Encrypts a file using the given key. + + Args: + file_path (str): The path to the file to encrypt. + key (bytes): The encryption key. + + Returns: + None + """ + f = Fernet(key) + with open(file_path, "rb") as file: + data = file.read() + encrypted_data = f.encrypt(data) + with open(file_path, "wb") as file: + file.write(encrypted_data) + + +def decrypt_file(file_path: str, key: bytes): + """