diff --git a/pkgs/core/swarmauri_core/crypto/types.py b/pkgs/core/swarmauri_core/crypto/types.py index 1b9dbd2f1e..0102ce0baa 100644 --- a/pkgs/core/swarmauri_core/crypto/types.py +++ b/pkgs/core/swarmauri_core/crypto/types.py @@ -62,6 +62,7 @@ class JWAAlg(str, Enum): RSA_OAEP = "RSA-OAEP" RSA_OAEP_256 = "RSA-OAEP-256" ECDH_ES = "ECDH-ES" + ECDH_ES_X25519_MLKEM768 = "ECDH-ES+X25519MLKEM768" DIR = "dir" A128GCM = "A128GCM" A192GCM = "A192GCM" diff --git a/pkgs/standards/swarmauri_crypto_jwe/pyproject.toml b/pkgs/standards/swarmauri_crypto_jwe/pyproject.toml index 9e20163480..6d5f69d86a 100644 --- a/pkgs/standards/swarmauri_crypto_jwe/pyproject.toml +++ b/pkgs/standards/swarmauri_crypto_jwe/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "swarmauri_core", "swarmauri_base", "cryptography>=41", + "pqcrypto>=0.3.1", ] keywords = [ 'swarmauri', diff --git a/pkgs/standards/swarmauri_crypto_jwe/swarmauri_crypto_jwe/JweCrypto.py b/pkgs/standards/swarmauri_crypto_jwe/swarmauri_crypto_jwe/JweCrypto.py index b55cd9f74e..8c866b8e8c 100644 --- a/pkgs/standards/swarmauri_crypto_jwe/swarmauri_crypto_jwe/JweCrypto.py +++ b/pkgs/standards/swarmauri_crypto_jwe/swarmauri_crypto_jwe/JweCrypto.py @@ -7,6 +7,7 @@ from __future__ import annotations import base64 +import binascii import json import os import zlib @@ -22,6 +23,8 @@ load_pem_private_key, ) +from pqcrypto.kem import kyber768 + from swarmauri_core.crypto.types import JWAAlg @@ -168,6 +171,38 @@ def _load_ecdh_recipient_public(jwk_or_pem: Any) -> Tuple[str, Any]: raise TypeError("Unsupported recipient public key format for ECDH-ES.") +def _bytes_from_any(value: Any, *, allow_mapping_key: str | None = None) -> bytes: + if allow_mapping_key and isinstance(value, Mapping): + if allow_mapping_key not in value: + raise ValueError( + f"Mapping is missing required key '{allow_mapping_key}' for ML-KEM-768" + ) + return _bytes_from_any(value[allow_mapping_key]) + if isinstance(value, (bytes, bytearray)): + return bytes(value) + if isinstance(value, str): + decoders = (_b64u_dec, lambda s: base64.b64decode(s, validate=False)) + for decoder in decoders: + try: + return decoder(value) + except (binascii.Error, ValueError): + continue + raise ValueError("Failed to decode ML-KEM key material from string.") + raise TypeError("Unsupported key material type; expected bytes or str.") + + +def _load_mlkem768_public(value: Any) -> bytes: + if value is None: + raise ValueError("ML-KEM-768 public key is required for hybrid encryption.") + return _bytes_from_any(value, allow_mapping_key="pub") + + +def _load_mlkem768_private(value: Any) -> bytes: + if value is None: + raise ValueError("ML-KEM-768 private key is required for hybrid decryption.") + return _bytes_from_any(value, allow_mapping_key="priv") + + def _concat_kdf( z: bytes, enc: JWAAlg, @@ -312,6 +347,54 @@ async def encrypt_compact( ) cek = _concat_kdf(z, enc, hashes.SHA256(), apu_b, apv_b) protected["epk"] = epk_header + elif alg == JWAAlg.ECDH_ES_X25519_MLKEM768: + x25519_info = key.get("x25519") + if x25519_info is None: + raise ValueError( + "Hybrid alg requires 'x25519' entry containing recipient public key." + ) + crv, rpk = _load_ecdh_recipient_public(x25519_info) + if crv != "X25519": + raise ValueError( + "Hybrid alg requires an X25519 recipient public key for the classical component." + ) + esk = x25519.X25519PrivateKey.generate() + epk = esk.public_key() + z_classical = esk.exchange(rpk) # type: ignore[arg-type] + epk_header = _x25519_jwk_from_public_key(epk) + + mlkem_pub = _load_mlkem768_public( + key.get("mlkem768") + or key.get("mlkem768_pub") + or key.get("pqc") + or key.get("mlkem") + ) + pqc_ciphertext, pqc_secret = kyber768.encapsulate(mlkem_pub) + + apu_b = None + apv_b = None + if "apu" in (header_extra or {}): + apu_b = ( + _b64u_dec(header_extra["apu"]) + if isinstance(header_extra["apu"], str) + else header_extra["apu"] + ) + if "apv" in (header_extra or {}): + apv_b = ( + _b64u_dec(header_extra["apv"]) + if isinstance(header_extra["apv"], str) + else header_extra["apv"] + ) + + cek = _concat_kdf( + z_classical + pqc_secret, enc, hashes.SHA256(), apu_b, apv_b + ) + protected["epk"] = epk_header + protected["pqc"] = { + "kty": "ML-KEM", + "kem": "ML-KEM-768", + "ct": _b64u(pqc_ciphertext), + } else: raise ValueError(f"Unsupported alg '{alg.value}'") @@ -343,6 +426,7 @@ async def decrypt_compact( rsa_private_pem: Optional[Union[str, bytes]] = None, rsa_private_password: Optional[Union[str, bytes]] = None, ecdh_private_key: Optional[Any] = None, + mlkem_private_key: Optional[Any] = None, expected_algs: Optional[Iterable[JWAAlg]] = None, expected_encs: Optional[Iterable[JWAAlg]] = None, aad: Optional[Union[bytes, str]] = None, @@ -355,6 +439,8 @@ async def decrypt_compact( RSA-OAEP algorithms. rsa_private_password (Union[str, bytes]): Password for the RSA key. ecdh_private_key (Any): Private key for ECDH-ES. + mlkem_private_key (Any): Private key for ML-KEM-768 when using the hybrid + algorithm. expected_algs (Iterable[JWAAlg]): Allowed algorithm values. expected_encs (Iterable[JWAAlg]): Allowed encryption values. aad (Union[bytes, str]): Additional authenticated data. @@ -439,6 +525,39 @@ async def decrypt_compact( apu_b = _b64u_dec(header["apu"]) if "apu" in header else None apv_b = _b64u_dec(header["apv"]) if "apv" in header else None cek = _concat_kdf(z, enc, hashes.SHA256(), apu_b, apv_b) + elif alg == JWAAlg.ECDH_ES_X25519_MLKEM768: + if not isinstance(ecdh_private_key, x25519.X25519PrivateKey): + raise TypeError( + "Hybrid alg requires an X25519PrivateKey for the classical component." + ) + if mlkem_private_key is None: + raise ValueError( + "mlkem_private_key is required for hybrid ECDH-ES+X25519MLKEM768 decryption." + ) + epk = header.get("epk") + if not (isinstance(epk, Mapping) and epk.get("kty") == "OKP"): + raise ValueError("Missing/invalid 'epk' for hybrid ECDH-ES header.") + if epk.get("crv") != "X25519": + raise ValueError("Hybrid 'epk' must declare crv='X25519'.") + z_classical = ecdh_private_key.exchange( + x25519.X25519PublicKey.from_public_bytes(_b64u_dec(epk["x"])) + ) + + pqc_info = header.get("pqc") + if not isinstance(pqc_info, Mapping): + raise ValueError("Missing 'pqc' object in hybrid header.") + ct_b64 = pqc_info.get("ct") + if not isinstance(ct_b64, str): + raise ValueError("Hybrid header 'pqc.ct' must be a base64url string.") + pqc_ciphertext = _b64u_dec(ct_b64) + mlkem_sk = _load_mlkem768_private(mlkem_private_key) + pqc_secret = kyber768.decapsulate(pqc_ciphertext, mlkem_sk) + + apu_b = _b64u_dec(header["apu"]) if "apu" in header else None + apv_b = _b64u_dec(header["apv"]) if "apv" in header else None + cek = _concat_kdf( + z_classical + pqc_secret, enc, hashes.SHA256(), apu_b, apv_b + ) else: raise ValueError(f"Unsupported alg '{alg.value}'") diff --git a/pkgs/standards/swarmauri_crypto_jwe/tests/test_hybrid_alg.py b/pkgs/standards/swarmauri_crypto_jwe/tests/test_hybrid_alg.py new file mode 100644 index 0000000000..5fd1df66de --- /dev/null +++ b/pkgs/standards/swarmauri_crypto_jwe/tests/test_hybrid_alg.py @@ -0,0 +1,65 @@ +import asyncio +import base64 +import json + +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import x25519 +from pqcrypto.kem import kyber768 + +from swarmauri_core.crypto.types import JWAAlg +from swarmauri_crypto_jwe import JweCrypto + + +def _b64u(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +@pytest.mark.unit +@pytest.mark.test +def test_ecdh_es_x25519_mlkem768_round_trip() -> None: + crypto = JweCrypto() + + recipient_x_priv = x25519.X25519PrivateKey.generate() + recipient_x_pub = recipient_x_priv.public_key() + recipient_x_pub_bytes = recipient_x_pub.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw, + ) + mlkem_pub, mlkem_priv = kyber768.generate_keypair() + + jwe = asyncio.run( + crypto.encrypt_compact( + payload=b"hybrid", + alg=JWAAlg.ECDH_ES_X25519_MLKEM768, + enc=JWAAlg.A256GCM, + key={ + "x25519": { + "kty": "OKP", + "crv": "X25519", + "x": _b64u(recipient_x_pub_bytes), + }, + "mlkem768": base64.b64encode(mlkem_pub).decode("ascii"), + }, + ) + ) + + protected_b64 = jwe.split(".")[0] + padding = "=" * ((4 - len(protected_b64) % 4) % 4) + protected = json.loads(base64.urlsafe_b64decode(protected_b64 + padding)) + + assert protected["alg"] == JWAAlg.ECDH_ES_X25519_MLKEM768.value + assert protected["enc"] == JWAAlg.A256GCM.value + assert protected["epk"]["crv"] == "X25519" + assert protected["pqc"]["kem"] == "ML-KEM-768" + assert isinstance(protected["pqc"]["ct"], str) + + result = asyncio.run( + crypto.decrypt_compact( + jwe, + ecdh_private_key=recipient_x_priv, + mlkem_private_key=mlkem_priv, + ) + ) + + assert result.plaintext == b"hybrid"