diff --git a/Cargo.lock b/Cargo.lock index 7324776..0329fff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,7 +23,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" dependencies = [ - "crypto-common", + "crypto-common 0.1.6", "generic-array", ] @@ -212,6 +212,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-buffer" +version = "0.11.0-rc.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9ef36a6fcdb072aa548f3da057640ec10859eb4e91ddf526ee648d50c76a949" +dependencies = [ + "hybrid-array", +] + [[package]] name = "bon" version = "3.7.0" @@ -298,7 +307,7 @@ version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" dependencies = [ - "crypto-common", + "crypto-common 0.1.6", "inout", ] @@ -431,6 +440,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "crypto-common" +version = "0.2.0-rc.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8235645834fbc6832939736ce2f2d08192652269e11010a6240f61b908a1c6" +dependencies = [ + "hybrid-array", +] + [[package]] name = "ctr" version = "0.9.2" @@ -440,6 +458,32 @@ dependencies = [ "cipher", ] +[[package]] +name = "curve25519-dalek" +version = "4.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" +dependencies = [ + "cfg-if", + "cpufeatures", + "curve25519-dalek-derive", + "fiat-crypto", + "rustc_version", + "subtle", + "zeroize", +] + +[[package]] +name = "curve25519-dalek-derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "darling" version = "0.20.11" @@ -537,12 +581,22 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", + "block-buffer 0.10.4", "const-oid", - "crypto-common", + "crypto-common 0.1.6", "subtle", ] +[[package]] +name = "digest" +version = "0.11.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dac89f8a64533a9b0eaa73a68e424db0fb1fd6271c74cc0125336a05f090568d" +dependencies = [ + "block-buffer 0.11.0-rc.5", + "crypto-common 0.2.0-rc.4", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -573,7 +627,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" dependencies = [ "der", - "digest", + "digest 0.10.7", "elliptic-curve", "rfc6979", "signature", @@ -597,7 +651,7 @@ checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" dependencies = [ "base16ct", "crypto-bigint", - "digest", + "digest 0.10.7", "ff", "generic-array", "group", @@ -664,6 +718,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "fiat-crypto" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28dea519a9695b9977216879a3ebfddf92f1c08c05d984f8996aecd6ecdc811d" + [[package]] name = "flume" version = "0.11.1" @@ -1097,7 +1157,7 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" dependencies = [ - "digest", + "digest 0.10.7", ] [[package]] @@ -1177,6 +1237,16 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "hybrid-array" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f471e0a81b2f90ffc0cb2f951ae04da57de8baa46fa99112b062a5173a5088d0" +dependencies = [ + "subtle", + "typenum", +] + [[package]] name = "hyper" version = "0.14.32" @@ -1586,6 +1656,25 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "keccak" +version = "0.2.0-rc.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d546793a04a1d3049bd192856f804cfe96356e2cf36b54b4e575155babe9f41" +dependencies = [ + "cpufeatures", +] + +[[package]] +name = "kem" +version = "0.3.0-pre.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4701a9c37a0843da68e19189250e8c62276fb551fe55bde787e9da480c45ee59" +dependencies = [ + "rand_core 0.9.3", + "zeroize", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1667,7 +1756,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ "cfg-if", - "digest", + "digest 0.10.7", ] [[package]] @@ -1702,6 +1791,19 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ml-kem" +version = "0.3.0-pre.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95ce26c413d3c2cd89fb1615c1ab7882f6cd429a6b3f1e184042cb9ec2745a8" +dependencies = [ + "hybrid-array", + "kem", + "rand_core 0.9.3", + "sha3", + "subtle", +] + [[package]] name = "num-bigint-dig" version = "0.8.4" @@ -2194,7 +2296,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47c75d7c5c6b673e58bf54d8544a9f432e3a925b0e80f7cd3602ab5c50c55519" dependencies = [ "const-oid", - "digest", + "digest 0.10.7", "num-bigint-dig", "num-integer", "num-traits", @@ -2449,7 +2551,7 @@ checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", - "digest", + "digest 0.10.7", ] [[package]] @@ -2460,7 +2562,17 @@ checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", - "digest", + "digest 0.10.7", +] + +[[package]] +name = "sha3" +version = "0.11.0-rc.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2103ca0e6f4e9505eae906de5e5883e06fc3b2232fb5d6914890c7bbcb62f478" +dependencies = [ + "digest 0.11.0-rc.3", + "keccak", ] [[package]] @@ -2484,7 +2596,7 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ - "digest", + "digest 0.10.7", "rand_core 0.6.4", ] @@ -2654,7 +2766,7 @@ dependencies = [ "bytes", "chrono", "crc", - "digest", + "digest 0.10.7", "dotenvy", "either", "futures-channel", @@ -2817,19 +2929,24 @@ dependencies = [ "base64 0.21.7", "clap", "google-cloud-secretmanager-v1", + "hkdf", "hyper 0.14.32", "hyperlocal", "jsonrpsee", "lru", + "ml-kem", "p256", + "rand 0.9.2", "rand_core 0.6.4", "serde", "serde_bytes", "serde_json", "serde_plain", + "sha2", "sqlx", "tokio", "uuid", + "x25519-dalek", ] [[package]] @@ -3142,7 +3259,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" dependencies = [ - "crypto-common", + "crypto-common 0.1.6", "subtle", ] @@ -3639,6 +3756,18 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +[[package]] +name = "x25519-dalek" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7e468321c81fb07fa7f4c636c3972b9100f0346e5b6a9f2bd0603a52f7ed277" +dependencies = [ + "curve25519-dalek", + "rand_core 0.6.4", + "serde", + "zeroize", +] + [[package]] name = "yoke" version = "0.7.5" @@ -3710,6 +3839,20 @@ name = "zeroize" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "zerovec" diff --git a/Cargo.toml b/Cargo.toml index 6edff6a..b4f622e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,8 +25,17 @@ rand_core = "0.6.4" serde_plain = "1.0.2" lru = "0.13.0" +# PQXDH dependencies +x25519-dalek = "2.0" +ml-kem = "0.3.0-pre.1" # ML-KEM standard (compatible with @noble/post-quantum) +hkdf = "0.12" +sha2 = "0.10" +rand = "0.9.0-beta.2" # For ml-kem's rand_core 0.9 dependency + [features] register = [] dsc = [] disclose = [] cherrypick = [] +# Test mode: mocks attestation service for local testing +test_mode = [] diff --git a/examples/pqxdh_test_server.rs b/examples/pqxdh_test_server.rs new file mode 100644 index 0000000..ee6f190 --- /dev/null +++ b/examples/pqxdh_test_server.rs @@ -0,0 +1,425 @@ +/// Minimal TEE server for PQXDH handshake testing. +/// This server only implements hello and key_exchange methods without requiring: +/// - TEE attestation service (mocked) +/// - PostgreSQL database +/// - Circuit/zkey files +/// - Proof generation +/// +/// Run with: cargo run --example pqxdh_test_server --features test_mode +/// +/// The server will listen on http://127.0.0.1:9944 by default. + +use jsonrpsee::core::async_trait; +use jsonrpsee::proc_macros::rpc; +use jsonrpsee::server::Server; +use jsonrpsee::types; +use jsonrpsee::{types::ErrorObjectOwned, ResponsePayload}; +use std::sync::Arc; + +use tee_server::store::{KeyMaterial, LruStore}; +use tee_server::types::HelloResponse; + +// importing PQXDH dependencies +use base64::engine::{general_purpose, Engine}; +use hkdf::Hkdf; +use ml_kem::kem::Decapsulate; +use ml_kem::{Encoded, EncodedSizeUser, KemCore, MlKem768}; +use p256::ecdh::EphemeralSecret; +use p256::elliptic_curve::sec1::ToEncodedPoint; +use p256::elliptic_curve::PublicKey; +use rand_core::OsRng; +use sha2::Sha256; +use x25519_dalek::{EphemeralSecret as X25519Secret, PublicKey as X25519PublicKey}; + +// mock attestation function for test mode +async fn get_mock_attestation(nonces: Vec<&str>) -> Result, Box> { + let mock_header = r#"{"alg":"RS256","typ":"JWT"}"#; + let mock_payload = format!( + r#"{{"nonces":[{}],"iat":{},"exp":{}}}"#, + nonces + .iter() + .map(|n| format!("\"{}\"", n)) + .collect::>() + .join(","), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs(), + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap() + .as_secs() + + 3600, + ); + + let encoded_header = general_purpose::URL_SAFE_NO_PAD.encode(mock_header); + let encoded_payload = general_purpose::URL_SAFE_NO_PAD.encode(mock_payload); + let mock_signature = general_purpose::URL_SAFE_NO_PAD.encode("mock_sig"); + + let mock_jwt = format!("{}.{}.{}", encoded_header, encoded_payload, mock_signature); + Ok(mock_jwt.into_bytes()) +} + +#[rpc(server, namespace = "openpassport")] +pub trait TestRpc { + #[method(name = "health")] + async fn health(&self) -> ResponsePayload<'static, String>; + + #[method(name = "hello")] + async fn hello( + &self, + user_pubkey: Vec, + uuid: uuid::Uuid, + supported_suites: Vec, + ) -> ResponsePayload<'static, HelloResponse>; + + #[method(name = "key_exchange")] + async fn key_exchange( + &self, + uuid: uuid::Uuid, + kyber_ciphertext: Vec, + ) -> ResponsePayload<'static, String>; + + /// DEBUG ONLY: Returns the derived session key for testing. + /// DO NOT use in production, as keys should never be exposed! + #[method(name = "debug_get_session_key")] + async fn debug_get_session_key(&self, uuid: uuid::Uuid) -> ResponsePayload<'static, Vec>; +} + +pub struct TestRpcServerImpl { + store: Arc, +} + +impl TestRpcServerImpl { + pub fn new(store: Arc) -> Self { + Self { store } + } +} + +#[async_trait] +impl TestRpcServer for TestRpcServerImpl { + async fn health(&self) -> ResponsePayload<'static, String> { + ResponsePayload::success("OK".to_string()) + } + + async fn hello( + &self, + user_pubkey: Vec, + uuid: uuid::Uuid, + supported_suites: Vec, + ) -> ResponsePayload<'static, HelloResponse> { + println!("Received hello from UUID: {}", uuid); + println!("Supported suites: {:?}", supported_suites); + + // negotiating suite + let selected_suite = if supported_suites.contains(&"Self-PQXDH-1".to_string()) { + "Self-PQXDH-1" + } else if supported_suites.contains(&"legacy-p256".to_string()) { + "legacy-p256" + } else { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidParams.code(), + "No supported cryptographic suite found", + None, + )); + }; + + println!("Selected suite: {}", selected_suite); + + if selected_suite == "Self-PQXDH-1" { + // PQXDH flow + if user_pubkey.len() != 32 { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + format!("X25519 public key must be 32 bytes, got {}", user_pubkey.len()), + None, + )); + } + + // generating X25519 keypair + let mut rng = OsRng; + let x25519_secret = X25519Secret::random_from_rng(&mut rng); + let x25519_public = X25519PublicKey::from(&x25519_secret); + + // parsing client's X25519 public key + let client_x25519_public = { + let mut key_bytes = [0u8; 32]; + key_bytes.copy_from_slice(&user_pubkey); + X25519PublicKey::from(key_bytes) + }; + + // computing X25519 shared secret + let x25519_shared = x25519_secret.diffie_hellman(&client_x25519_public); + + // generating Kyber ML-KEM-768 keypair (using system RNG) + let (decapsulation_key, encapsulation_key) = MlKem768::generate(&mut rand::rng()); + + // storing pending state + let key_material = KeyMaterial::PqxdhPending { + x25519_shared: x25519_shared.as_bytes().to_vec(), + kyber_secret: decapsulation_key.as_bytes().to_vec(), + }; + + match self.store.insert_new_agreement(uuid, key_material).await { + Ok(_) => println!("Stored pending PQXDH state for UUID: {}", uuid), + Err(e) => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + format!("UUID already exists: {}", e), + None, + )); + } + } + + // encoding public keys + let x25519_b64 = general_purpose::STANDARD.encode(x25519_public.as_bytes()); + let kyber_b64 = general_purpose::STANDARD.encode(encapsulation_key.as_bytes()); + + // generating mock attestation + let attestation_result = + get_mock_attestation(vec![&x25519_b64, &kyber_b64, selected_suite]) + .await + .map_err(|e| format!("{:?}", e)); + + let attestation = match attestation_result { + Ok(att) => att, + Err(e) => { + self.store.remove_agreement(&uuid).await; + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InternalError.code(), + format!("Attestation failed: {}", e), + None, + )); + } + }; + + let encaps_key_bytes = encapsulation_key.as_bytes(); + println!( + "Returning PQXDH response: X25519 pubkey {} bytes, Kyber pubkey {} bytes", + x25519_public.as_bytes().len(), + encaps_key_bytes.len() + ); + println!("Server Kyber public key (first 16 bytes): {:02x?}", &encaps_key_bytes[..16]); + + ResponsePayload::success( + HelloResponse::new( + uuid, + attestation, + selected_suite.to_string(), + Some(x25519_public.as_bytes().to_vec()), + Some(encapsulation_key.as_bytes().to_vec()), + ) + .into(), + ) + } else { + // legacy P-256 flow + if user_pubkey.len() != 33 { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + "P-256 public key must be 33 bytes", + None, + )); + } + + let mut rng = OsRng; + let my_private_key = EphemeralSecret::random(&mut rng); + let my_public_key = PublicKey::from(&my_private_key); + + let their_public_key = match PublicKey::from_sec1_bytes(&user_pubkey) { + Ok(pubkey) => pubkey, + Err(e) => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidParams.code(), + format!("{:?}", e), + None, + )); + } + }; + + let their_key_b64 = general_purpose::STANDARD + .encode(&their_public_key.to_encoded_point(true).to_bytes()); + let my_key_b64 = + general_purpose::STANDARD.encode(&my_public_key.to_encoded_point(true).to_bytes()); + + let attestation_result = get_mock_attestation(vec![&their_key_b64, &my_key_b64]) + .await + .map_err(|e| format!("{:?}", e)); + + let attestation = match attestation_result { + Ok(att) => att, + Err(e) => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InternalError.code(), + format!("Attestation failed: {}", e), + None, + )); + } + }; + + let shared_secret = my_private_key + .diffie_hellman(&their_public_key) + .raw_secret_bytes() + .to_vec(); + + let key_material = KeyMaterial::LegacyP256(shared_secret); + + match self.store.insert_new_agreement(uuid, key_material).await { + Ok(_) => println!("Stored legacy P-256 key for UUID: {}", uuid), + Err(e) => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + format!("UUID already exists: {}", e), + None, + )); + } + } + + ResponsePayload::success( + HelloResponse::new(uuid, attestation, selected_suite.to_string(), None, None) + .into(), + ) + } + } + + async fn key_exchange( + &self, + uuid: uuid::Uuid, + kyber_ciphertext: Vec, + ) -> ResponsePayload<'static, String> { + println!("Received key_exchange from UUID: {}", uuid); + println!("Kyber ciphertext length: {}", kyber_ciphertext.len()); + + // validating ciphertext length + const ML_KEM_768_CIPHERTEXT_SIZE: usize = 1088; + if kyber_ciphertext.len() != ML_KEM_768_CIPHERTEXT_SIZE { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + format!( + "Invalid Kyber ciphertext: expected {} bytes, got {}", + ML_KEM_768_CIPHERTEXT_SIZE, + kyber_ciphertext.len() + ), + None, + )); + } + + // retrieving pending state + let key_material = match self.store.get_key_material(&uuid).await { + Some(m) => m, + None => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + "UUID not found", + None, + )); + } + }; + + let (x25519_shared, kyber_secret) = match key_material { + KeyMaterial::PqxdhPending { + x25519_shared, + kyber_secret, + } => (x25519_shared, kyber_secret), + _ => { + self.store.remove_agreement(&uuid).await; + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + "UUID is not in PQXDH pending state", + None, + )); + } + }; + + // decapsulating + type DecapKey = ::DecapsulationKey; + let decaps_key_bytes: &Encoded = kyber_secret[..].try_into().map_err(|_| { + "Invalid decapsulation key length" + }).unwrap(); + let decaps_key = DecapKey::from_bytes(decaps_key_bytes); + + // parsing ciphertext from bytes + let ct: ml_kem::Ciphertext = kyber_ciphertext[..].try_into().map_err(|_| { + "Invalid ciphertext length" + }).unwrap(); + + let kyber_shared = decaps_key.decapsulate(&ct).map_err(|e| { + format!("Decapsulation failed: {:?}", e) + }).unwrap(); + let kyber_shared_bytes: &[u8] = kyber_shared.as_ref(); + + println!("Server X25519 shared secret (first 8 bytes): {:02x?}", &x25519_shared[..8]); + println!("Server Kyber shared secret (first 8 bytes): {:02x?}", &kyber_shared_bytes[..8]); + + // deriving session key using Signal PQXDH spec + let f_prefix = vec![0xff; 32]; + let mut ikm = Vec::with_capacity(f_prefix.len() + x25519_shared.len() + kyber_shared_bytes.len()); + ikm.extend_from_slice(&f_prefix); + ikm.extend_from_slice(&x25519_shared); + ikm.extend_from_slice(&kyber_shared_bytes); + + let salt = vec![0u8; 32]; + let info = b"Self-PQXDH-1_X25519_SHA-256_ML-KEM-768"; + + let hkdf = Hkdf::::new(Some(&salt), &ikm); + let mut session_key = vec![0u8; 32]; + if let Err(e) = hkdf.expand(info, &mut session_key) { + self.store.remove_agreement(&uuid).await; + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InternalError.code(), + format!("HKDF failed: {:?}", e), + None, + )); + } + + println!("Derived session key (first 8 bytes): {:02x?}", &session_key[..8]); + + // updating store + let complete_material = KeyMaterial::PqxdhComplete(session_key); + if let Err(e) = self.store.update_key_material(&uuid, complete_material).await { + self.store.remove_agreement(&uuid).await; + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InternalError.code(), + format!("Failed to update key material: {}", e), + None, + )); + } + + println!("✅ Key exchange complete for UUID: {}", uuid); + ResponsePayload::success("key_exchange_complete".to_string()) + } + + async fn debug_get_session_key(&self, uuid: uuid::Uuid) -> ResponsePayload<'static, Vec> { + match self.store.get_shared_secret(&uuid).await { + Some(key) => { + println!("DEBUG: Returning session key for UUID: {}", uuid); + ResponsePayload::success(key) + } + None => ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + "UUID not found or not in complete state", + None, + )), + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addr = "127.0.0.1:9944"; + println!("🚀 Starting PQXDH test server on {}", addr); + println!("📝 This server is for testing PQXDH handshake only"); + println!(" - Attestation is mocked (not cryptographically valid)"); + println!(" - No database required"); + println!(" - Only hello and key_exchange methods available\n"); + + let server = Server::builder().build(addr).await?; + let store = Arc::new(LruStore::new(1000)); + + let handle = server.start(TestRpcServerImpl::new(store).into_rpc()); + + println!("✅ Server ready at ws://{}\n", addr); + println!("Press Ctrl+C to stop\n"); + + handle.stopped().await; + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..9cc0ebf --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,8 @@ +// Library exports for integration tests. +// This file exposes internal modules so they can be tested. + +pub mod db; +pub mod generator; +pub mod store; +pub mod types; +pub mod utils; diff --git a/src/server.rs b/src/server.rs index 9bffbb3..5669d90 100644 --- a/src/server.rs +++ b/src/server.rs @@ -12,21 +12,44 @@ use std::collections::HashMap; use std::sync::Arc; use crate::db::create_proof_status; -use crate::store::LruStore; +use crate::store::{KeyMaterial, LruStore}; use crate::types::{ProofRequest, SubmitRequest}; use crate::utils; use crate::{generator::file_generator::FileGenerator, types::HelloResponse}; +// PQXDH imports +use hkdf::Hkdf; +use ml_kem::kem::Decapsulate; +use ml_kem::{Encoded, EncodedSizeUser, KemCore, MlKem768}; +use sha2::Sha256; +use x25519_dalek::{EphemeralSecret as X25519Secret, PublicKey as X25519PublicKey}; + #[rpc(server, namespace = "openpassport")] pub trait Rpc { #[method(name = "health")] async fn health(&self) -> ResponsePayload<'static, String>; + + /// Initiates a cryptographic handshake with the client. + /// Negotiates suite (PQXDH or legacy P-256), generates keypairs, computes shared secrets, + /// and returns attestation with public keys. For PQXDH, stores pending state awaiting key_exchange. #[method(name = "hello")] async fn hello( &self, user_pubkey: Vec, uuid: uuid::Uuid, + supported_suites: Vec, ) -> ResponsePayload<'static, HelloResponse>; + + /// Completes the PQXDH handshake by decapsulating the Kyber ciphertext. + /// Derives the final session key using HKDF per Signal PQXDH spec. + /// Only valid for UUIDs in PqxdhPending state after hello. + #[method(name = "key_exchange")] + async fn key_exchange( + &self, + uuid: uuid::Uuid, + kyber_ciphertext: Vec, + ) -> ResponsePayload<'static, String>; + #[method(name = "submit_request")] async fn submit_request( &self, @@ -70,75 +93,280 @@ impl RpcServer for RpcServerImpl { &self, user_pubkey: Vec, uuid: uuid::Uuid, + supported_suites: Vec, ) -> ResponsePayload<'static, HelloResponse> { - if user_pubkey.len() != 33 { + // negotiating suite: prefer PQXDH, fallback to legacy + let selected_suite = if supported_suites.contains(&"Self-PQXDH-1".to_string()) { + "Self-PQXDH-1" + } else if supported_suites.contains(&"legacy-p256".to_string()) { + "legacy-p256" + } else { return ResponsePayload::error(ErrorObjectOwned::owned::( - types::ErrorCode::InvalidRequest.code(), //BAD_REQUEST - "Public key must be 33 bytes", + types::ErrorCode::InvalidParams.code(), + "No supported cryptographic suite found", None, )); }; - let mut rng = OsRng; - let my_private_key = EphemeralSecret::random(&mut rng); - let my_public_key = PublicKey::from(&my_private_key); - - let their_public_key = match PublicKey::from_sec1_bytes(&user_pubkey) { - Ok(pubkey) => pubkey, - Err(err) => { + if selected_suite == "Self-PQXDH-1" { + // PQXDH flow: X25519 + Kyber ML-KEM-768 + if user_pubkey.len() != 32 { return ResponsePayload::error(ErrorObjectOwned::owned::( - types::ErrorCode::InvalidParams.code(), //INVALID_PARAMS - format!("{:?}", err), + types::ErrorCode::InvalidRequest.code(), + "X25519 public key must be 32 bytes", None, )); } - }; - let their_public_key_compressed = - their_public_key.to_encoded_point(true).to_bytes().to_vec(); - let my_public_key_compressed = my_public_key.to_encoded_point(true).to_bytes().to_vec(); + // generating X25519 keypair + let mut rng = OsRng; + let x25519_secret = X25519Secret::random_from_rng(&mut rng); + let x25519_public = X25519PublicKey::from(&x25519_secret); - let their_public_key_string = - general_purpose::STANDARD.encode(&their_public_key_compressed); - let my_public_key_string = general_purpose::STANDARD.encode(&my_public_key_compressed); + // parsing client's X25519 public key + let client_x25519_public = { + let mut key_bytes = [0u8; 32]; + key_bytes.copy_from_slice(&user_pubkey); + X25519PublicKey::from(key_bytes) + }; - let attestation = match utils::attestation::get_custom_token_bytes(vec![ - &their_public_key_string, - &my_public_key_string, - ]) - .await - { - Ok(attestation) => attestation, - Err(err) => { + // computing X25519 shared secret + let x25519_shared = x25519_secret.diffie_hellman(&client_x25519_public); + + // generating Kyber ML-KEM-768 keypair (using system RNG) + let (decapsulation_key, encapsulation_key) = MlKem768::generate(&mut rand::rng()); + + // storing X25519 shared secret and Kyber secret key (waiting for key_exchange) + let key_material = KeyMaterial::PqxdhPending { + x25519_shared: x25519_shared.as_bytes().to_vec(), + kyber_secret: decapsulation_key.as_bytes().to_vec(), + }; + + match self.store.insert_new_agreement(uuid, key_material).await { + Ok(_) => (), + Err(_) => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + "UUID already exists", + None, + )); + } + } + + // encoding public keys for attestation + let x25519_public_b64 = general_purpose::STANDARD.encode(x25519_public.as_bytes()); + let kyber_public_b64 = general_purpose::STANDARD.encode(encapsulation_key.as_bytes()); + + // creating attestation JWT with suite and public keys + let attestation_result = utils::attestation::get_custom_token_bytes(vec![ + &x25519_public_b64, + &kyber_public_b64, + selected_suite, + ]) + .await + .map_err(|e| format!("{:?}", e)); + + let attestation = match attestation_result { + Ok(attestation) => attestation, + Err(err_string) => { + self.store.remove_agreement(&uuid).await; + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InternalError.code(), + err_string, + None, + )); + } + }; + + ResponsePayload::success( + HelloResponse::new( + uuid, + attestation, + selected_suite.to_string(), + Some(x25519_public.as_bytes().to_vec()), + Some(encapsulation_key.as_bytes().to_vec()), + ) + .into(), + ) + } else { + // legacy P-256 ECDH flow + if user_pubkey.len() != 33 { return ResponsePayload::error(ErrorObjectOwned::owned::( - types::ErrorCode::InternalError.code(), //INTERNAL_SERVER_ERROR - format!("{:?}", err), + types::ErrorCode::InvalidRequest.code(), + "P-256 public key must be 33 bytes", None, )); } - }; - let derived_key_result = my_private_key - .diffie_hellman(&their_public_key) - .raw_secret_bytes() - .to_vec(); + let mut rng = OsRng; + let my_private_key = EphemeralSecret::random(&mut rng); + let my_public_key = PublicKey::from(&my_private_key); - match self - .store - .insert_new_agreement(uuid, derived_key_result) + let their_public_key = match PublicKey::from_sec1_bytes(&user_pubkey) { + Ok(pubkey) => pubkey, + Err(err) => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidParams.code(), + format!("{:?}", err), + None, + )); + } + }; + + let their_public_key_compressed = + their_public_key.to_encoded_point(true).to_bytes().to_vec(); + let my_public_key_compressed = my_public_key.to_encoded_point(true).to_bytes().to_vec(); + + let their_public_key_string = + general_purpose::STANDARD.encode(&their_public_key_compressed); + let my_public_key_string = general_purpose::STANDARD.encode(&my_public_key_compressed); + + let attestation = match utils::attestation::get_custom_token_bytes(vec![ + &their_public_key_string, + &my_public_key_string, + ]) .await - { - Ok(_) => (), - Err(_) => { + { + Ok(attestation) => attestation, + Err(err) => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InternalError.code(), + format!("{:?}", err), + None, + )); + } + }; + + let derived_key_result = my_private_key + .diffie_hellman(&their_public_key) + .raw_secret_bytes() + .to_vec(); + + let key_material = KeyMaterial::LegacyP256(derived_key_result); + + match self.store.insert_new_agreement(uuid, key_material).await { + Ok(_) => (), + Err(_) => { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + "UUID already exists", + None, + )); + } + } + + ResponsePayload::success( + HelloResponse::new(uuid, attestation, selected_suite.to_string(), None, None) + .into(), + ) + } + } + + async fn key_exchange( + &self, + uuid: uuid::Uuid, + kyber_ciphertext: Vec, + ) -> ResponsePayload<'static, String> { + // validating Kyber ciphertext length (ML-KEM-768 = 1088 bytes) + const ML_KEM_768_CIPHERTEXT_SIZE: usize = 1088; + if kyber_ciphertext.len() != ML_KEM_768_CIPHERTEXT_SIZE { + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + format!( + "Invalid Kyber ciphertext length: expected {}, got {}", + ML_KEM_768_CIPHERTEXT_SIZE, + kyber_ciphertext.len() + ), + None, + )); + } + + // retrieving pending PQXDH key material from store + let key_material = match self.store.get_key_material(&uuid).await { + Some(material) => material, + None => { return ResponsePayload::error(ErrorObjectOwned::owned::( - types::ErrorCode::InvalidRequest.code(), //INTERNAL_SERVER_ERROR - "UUID already exists", + types::ErrorCode::InvalidRequest.code(), + "UUID not found", None, )); } + }; + + let (x25519_shared, kyber_secret) = match key_material { + KeyMaterial::PqxdhPending { + x25519_shared, + kyber_secret, + } => (x25519_shared, kyber_secret), + _ => { + self.store.remove_agreement(&uuid).await; + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InvalidRequest.code(), + "UUID is not in PQXDH pending state", + None, + )); + } + }; + + // decapsulating Kyber ciphertext to get Kyber shared secret + // converting stored secret key bytes back to DecapsulationKey + type DecapKey = ::DecapsulationKey; + let decaps_key_bytes: &Encoded = kyber_secret[..].try_into().map_err(|_| { + "Invalid decapsulation key length" + }).unwrap(); + let decaps_key = DecapKey::from_bytes(decaps_key_bytes); + + // converting ciphertext bytes to Ciphertext type + let ct: ml_kem::Ciphertext = kyber_ciphertext[..].try_into().map_err(|_| { + "Invalid ciphertext length" + }).unwrap(); + + // decapsulating to get shared secret + let kyber_shared = decaps_key.decapsulate(&ct).map_err(|e| { + format!("Decapsulation failed: {:?}", e) + }).unwrap(); + let kyber_shared_bytes: &[u8] = kyber_shared.as_ref(); + + // deriving final session key using HKDF matching Signal PQXDH spec + // F prefix (32 0xFF bytes) per Signal spec + let f_prefix = vec![0xff; 32]; + + // IKM = F || X25519_shared || Kyber_shared + let mut ikm = Vec::with_capacity(f_prefix.len() + x25519_shared.len() + kyber_shared_bytes.len()); + ikm.extend_from_slice(&f_prefix); + ikm.extend_from_slice(&x25519_shared); + ikm.extend_from_slice(&kyber_shared_bytes); + + // zero-filled salt (32 bytes for SHA-256 output length) per Signal spec + let salt = vec![0u8; 32]; + + // info parameter following Signal pattern: "protocol_curve_hash_pqkem" + let info = b"Self-PQXDH-1_X25519_SHA-256_ML-KEM-768"; + + // deriving 32-byte session key using HKDF-SHA256 + let hkdf = Hkdf::::new(Some(&salt), &ikm); + let mut session_key = vec![0u8; 32]; + if let Err(e) = hkdf.expand(info, &mut session_key) { + self.store.remove_agreement(&uuid).await; + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InternalError.code(), + format!("HKDF expansion failed: {:?}", e), + None, + )); + } + + // updating store with final session key + let final_material = KeyMaterial::PqxdhComplete(session_key); + if let Err(e) = self.store.update_key_material(&uuid, final_material).await { + self.store.remove_agreement(&uuid).await; + return ResponsePayload::error(ErrorObjectOwned::owned::( + types::ErrorCode::InternalError.code(), + format!("Failed to update key material: {}", e), + None, + )); } - ResponsePayload::success(HelloResponse::new(uuid, attestation).into()) + ResponsePayload::success("key_exchange_complete".to_string()) } //TODO: check if circuit exists diff --git a/src/store.rs b/src/store.rs index d03dbe6..74cd64e 100644 --- a/src/store.rs +++ b/src/store.rs @@ -4,8 +4,23 @@ use std::num::NonZeroUsize; use lru::LruCache; use tokio::sync::Mutex; +/// Key material stored in the LRU cache during handshake. +#[derive(Clone)] +pub enum KeyMaterial { + /// Legacy P-256 ECDH shared secret (final session key). + LegacyP256(Vec), + /// PQXDH pending state: X25519 shared secret and Kyber secret key. + /// Waiting for client to send Kyber ciphertext in key_exchange call. + PqxdhPending { + x25519_shared: Vec, + kyber_secret: Vec, + }, + /// PQXDH complete: final session key derived via HKDF. + PqxdhComplete(Vec), +} + pub struct LruStore { - ecdh_store: Mutex>>, + ecdh_store: Mutex>, } impl LruStore { @@ -17,27 +32,61 @@ impl LruStore { } impl LruStore { + /// Inserts new key material for a UUID, returning an error if UUID already exists. + /// Used during initial handshake to store either legacy P-256 keys or PQXDH pending state. pub async fn insert_new_agreement( &self, uuid: uuid::Uuid, - shared_secret: Vec, + key_material: KeyMaterial, ) -> Result<(), String> { let mut cache = self.ecdh_store.lock().await; if cache.contains(&uuid.to_string()) { return Err("Duplicate uuid".to_string()); } else { - cache.put(uuid.to_string(), shared_secret); + cache.put(uuid.to_string(), key_material); } return Ok(()); } - pub async fn get_shared_secret(&self, uuid: &uuid::Uuid) -> Option> { + /// Retrieves key material for a UUID without removing it from the store. + /// Returns None if UUID not found. Used to check handshake state. + pub async fn get_key_material(&self, uuid: &uuid::Uuid) -> Option { let mut cache = self.ecdh_store.lock().await; cache.get(&uuid.to_string()).map(|x| x.clone()) } + /// Updates existing key material for a UUID, typically transitioning from pending to complete state. + /// Returns an error if UUID not found. Used during PQXDH key_exchange to finalize session key. + pub async fn update_key_material( + &self, + uuid: &uuid::Uuid, + key_material: KeyMaterial, + ) -> Result<(), String> { + let mut cache = self.ecdh_store.lock().await; + if cache.contains(&uuid.to_string()) { + cache.put(uuid.to_string(), key_material); + Ok(()) + } else { + Err("UUID not found".to_string()) + } + } + + /// Retrieves the final session key for a UUID if handshake is complete. + /// Returns None if UUID not found or still in pending state. + /// Used by submit_request to decrypt client payloads. + pub async fn get_shared_secret(&self, uuid: &uuid::Uuid) -> Option> { + let mut cache = self.ecdh_store.lock().await; + cache.get(&uuid.to_string()).and_then(|material| match material { + KeyMaterial::LegacyP256(key) => Some(key.clone()), + KeyMaterial::PqxdhComplete(key) => Some(key.clone()), + KeyMaterial::PqxdhPending { .. } => None, + }) + } + + /// Removes key material for a UUID from the store. + /// Used for cleanup on errors or after proof submission. pub async fn remove_agreement(&self, uuid: &uuid::Uuid) { let mut cache = self.ecdh_store.lock().await; cache.pop(&uuid.to_string()); diff --git a/src/types.rs b/src/types.rs index fc2bec0..221ab87 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,15 +3,39 @@ use serde::{Deserialize, Serialize}; use crate::generator::Circuit; +/// Response to the initial handshake containing attestation and cryptographic suite information. +/// For PQXDH handshakes, includes X25519 and Kyber public keys for post-quantum security. +/// For legacy P-256 handshakes, the PQXDH fields are omitted. #[derive(Serialize, Clone)] pub struct HelloResponse { uuid: uuid::Uuid, attestation: Vec, + /// Selected cryptographic suite: "Self-PQXDH-1" or "legacy-p256" + selected_suite: String, + /// Server's X25519 public key (32 bytes) for PQXDH handshakes + #[serde(skip_serializing_if = "Option::is_none")] + x25519_pubkey: Option>, + /// Server's Kyber ML-KEM-768 encapsulation key (1184 bytes) for PQXDH handshakes + #[serde(skip_serializing_if = "Option::is_none")] + kyber_pubkey: Option>, } impl HelloResponse { - pub fn new(uuid: uuid::Uuid, attestation: Vec) -> Self { - HelloResponse { uuid, attestation } + /// Creates a new HelloResponse with the specified cryptographic suite and optional PQXDH keys. + pub fn new( + uuid: uuid::Uuid, + attestation: Vec, + selected_suite: String, + x25519_pubkey: Option>, + kyber_pubkey: Option>, + ) -> Self { + HelloResponse { + uuid, + attestation, + selected_suite, + x25519_pubkey, + kyber_pubkey, + } } } diff --git a/src/utils.rs b/src/utils.rs index d387768..1ae971d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -41,11 +41,19 @@ pub async fn cleanup(uuid: uuid::Uuid, pool: &sqlx::Pool, reason pub mod attestation { use std::error::Error; + #[cfg(feature = "test_mode")] + use base64::{engine::general_purpose, Engine}; + + #[cfg(not(feature = "test_mode"))] use hyper::body::Buf; + #[cfg(not(feature = "test_mode"))] use hyper::{Body, Client, Request}; + #[cfg(not(feature = "test_mode"))] use hyperlocal::{UnixClientExt, Uri as HyperlocalUri}; + #[cfg(not(feature = "test_mode"))] use serde::Serialize; + #[cfg(not(feature = "test_mode"))] #[derive(Serialize)] struct TokenRequest<'a> { audience: &'a str, @@ -53,6 +61,35 @@ pub mod attestation { nonces: Vec<&'a str>, } + /// Generates a mock attestation token for local testing (test_mode feature only). + /// Returns a JWT-like structure with base64url-encoded header, payload, and signature. + /// NOT cryptographically valid - for development and testing purposes only. + #[cfg(feature = "test_mode")] + pub async fn get_custom_token_bytes(nonces: Vec<&str>) -> Result, Box> { + // creating mock JWT header + let mock_header = r#"{"alg":"RS256","typ":"JWT"}"#; + + // creating mock JWT payload with nonces and timestamps + let mock_payload = format!( + r#"{{"nonces":[{}],"iat":{},"exp":{}}}"#, + nonces.iter().map(|n| format!("\"{}\"", n)).collect::>().join(","), + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs(), + std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs() + 3600, + ); + + // encoding header and payload using base64url (not crypto-safe, just for structure) + let encoded_header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(mock_header); + let encoded_payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(mock_payload); + let mock_signature = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("mock_signature_for_testing"); + + // assembling JWT structure: header.payload.signature + let mock_jwt = format!("{}.{}.{}", encoded_header, encoded_payload, mock_signature); + + println!("Mock attestation token generated for nonces: {:?}", nonces); + Ok(mock_jwt.into_bytes()) + } + + #[cfg(not(feature = "test_mode"))] pub async fn get_custom_token_bytes(nonces: Vec<&str>) -> Result, Box> { let request_body = TokenRequest { audience: "USER", diff --git a/tests/store_tests.rs b/tests/store_tests.rs new file mode 100644 index 0000000..1edac7d --- /dev/null +++ b/tests/store_tests.rs @@ -0,0 +1,226 @@ +use tee_server::store::{KeyMaterial, LruStore}; +use uuid::Uuid; + +/// Tests basic LruStore insert and retrieve operations with legacy P-256 key material. +/// Verifies that stored shared secrets can be successfully retrieved. +#[tokio::test] +async fn test_store_insert_and_retrieve_legacy() { + let store = LruStore::new(100); + let uuid = Uuid::new_v4(); + let shared_secret = vec![0x42; 32]; + + let material = KeyMaterial::LegacyP256(shared_secret.clone()); + + // inserting key material + store + .insert_new_agreement(uuid, material) + .await + .expect("Insert should succeed"); + + // retrieving shared secret + let retrieved = store.get_shared_secret(&uuid).await; + assert!(retrieved.is_some(), "Should retrieve shared secret"); + assert_eq!(retrieved.unwrap(), shared_secret, "Shared secrets should match"); +} + +/// Tests LruStore with PQXDH pending state insertion and retrieval. +/// Verifies that pending state does not return session keys until completion. +#[tokio::test] +async fn test_store_pqxdh_pending() { + let store = LruStore::new(100); + let uuid = Uuid::new_v4(); + + let x25519_shared = vec![0x01; 32]; + let kyber_secret = vec![0x02; 2400]; + + let material = KeyMaterial::PqxdhPending { + x25519_shared: x25519_shared.clone(), + kyber_secret: kyber_secret.clone(), + }; + + // inserting pending material + store + .insert_new_agreement(uuid, material) + .await + .expect("Insert should succeed"); + + // get_shared_secret should return None for pending state + let retrieved = store.get_shared_secret(&uuid).await; + assert!(retrieved.is_none(), "Should not return shared secret for pending state"); + + // get_key_material should return the pending state + let material_retrieved = store.get_key_material(&uuid).await; + assert!(material_retrieved.is_some(), "Should retrieve key material"); + + match material_retrieved.unwrap() { + KeyMaterial::PqxdhPending { + x25519_shared: x, + kyber_secret: k, + } => { + assert_eq!(x, x25519_shared, "X25519 shared secrets should match"); + assert_eq!(k, kyber_secret, "Kyber secret keys should match"); + } + _ => panic!("Expected PqxdhPending state"), + } +} + +/// Tests updating PQXDH state from pending to complete via update_key_material. +/// Simulates the transition that occurs during key_exchange RPC call. +#[tokio::test] +async fn test_store_pqxdh_update_to_complete() { + let store = LruStore::new(100); + let uuid = Uuid::new_v4(); + + // inserting pending state + let material = KeyMaterial::PqxdhPending { + x25519_shared: vec![0x01; 32], + kyber_secret: vec![0x02; 2400], + }; + + store + .insert_new_agreement(uuid, material) + .await + .expect("Insert should succeed"); + + // updating to complete state + let session_key = vec![0x99; 32]; + let complete_material = KeyMaterial::PqxdhComplete(session_key.clone()); + + store + .update_key_material(&uuid, complete_material) + .await + .expect("Update should succeed"); + + // verifying update + let retrieved = store.get_shared_secret(&uuid).await; + assert!(retrieved.is_some(), "Should retrieve shared secret after completion"); + assert_eq!(retrieved.unwrap(), session_key, "Session keys should match"); +} + +/// Tests that duplicate UUID insertion fails with an error. +/// Ensures each UUID can only be used once per handshake session. +#[tokio::test] +async fn test_store_duplicate_uuid_fails() { + let store = LruStore::new(100); + let uuid = Uuid::new_v4(); + + let material1 = KeyMaterial::LegacyP256(vec![0x01; 32]); + let material2 = KeyMaterial::LegacyP256(vec![0x02; 32]); + + // first insert should succeed + store + .insert_new_agreement(uuid, material1) + .await + .expect("First insert should succeed"); + + // second insert with same UUID should fail + let result = store.insert_new_agreement(uuid, material2).await; + assert!(result.is_err(), "Duplicate UUID should fail"); +} + +/// Tests updating non-existent UUID fails with an error. +/// Prevents invalid state transitions for uninitialized sessions. +#[tokio::test] +async fn test_store_update_nonexistent_fails() { + let store = LruStore::new(100); + let uuid = Uuid::new_v4(); + + let material = KeyMaterial::PqxdhComplete(vec![0x99; 32]); + + // updating non-existent UUID should fail + let result = store.update_key_material(&uuid, material).await; + assert!(result.is_err(), "Update of non-existent UUID should fail"); +} + +/// Tests removing key material from store via remove_agreement. +/// Used for cleanup on errors or after successful proof submission. +#[tokio::test] +async fn test_store_remove_agreement() { + let store = LruStore::new(100); + let uuid = Uuid::new_v4(); + + let material = KeyMaterial::LegacyP256(vec![0x42; 32]); + + store + .insert_new_agreement(uuid, material) + .await + .expect("Insert should succeed"); + + // verifying key material exists + assert!(store.get_key_material(&uuid).await.is_some(), "Key material should exist"); + + // removing key material + store.remove_agreement(&uuid).await; + + // verifying key material is gone + assert!(store.get_key_material(&uuid).await.is_none(), "Key material should be removed"); +} + +/// Tests LRU eviction policy when cache reaches capacity. +/// Verifies that least recently used entries are evicted when cache is full. +#[tokio::test] +async fn test_store_lru_eviction() { + let store = LruStore::new(2); // small cache size + let uuid1 = Uuid::new_v4(); + let uuid2 = Uuid::new_v4(); + let uuid3 = Uuid::new_v4(); + + let material1 = KeyMaterial::LegacyP256(vec![0x01; 32]); + let material2 = KeyMaterial::LegacyP256(vec![0x02; 32]); + let material3 = KeyMaterial::LegacyP256(vec![0x03; 32]); + + // inserting first two entries + store.insert_new_agreement(uuid1, material1).await.unwrap(); + store.insert_new_agreement(uuid2, material2).await.unwrap(); + + // both should be retrievable + assert!(store.get_key_material(&uuid1).await.is_some()); + assert!(store.get_key_material(&uuid2).await.is_some()); + + // inserting third entry should evict the least recently used (uuid1) + store.insert_new_agreement(uuid3, material3).await.unwrap(); + + // uuid1 should be evicted + assert!(store.get_key_material(&uuid1).await.is_none(), "First UUID should be evicted"); + assert!(store.get_key_material(&uuid2).await.is_some(), "Second UUID should still exist"); + assert!(store.get_key_material(&uuid3).await.is_some(), "Third UUID should exist"); +} + +/// Tests that get_shared_secret returns correct values for all KeyMaterial variants. +/// Legacy and Complete return keys, Pending returns None until handshake completes. +#[tokio::test] +async fn test_store_get_shared_secret_variants() { + let store = LruStore::new(100); + + // legacy P-256 should return the shared secret + let uuid1 = Uuid::new_v4(); + let legacy_key = vec![0x42; 32]; + store + .insert_new_agreement(uuid1, KeyMaterial::LegacyP256(legacy_key.clone())) + .await + .unwrap(); + assert_eq!(store.get_shared_secret(&uuid1).await.unwrap(), legacy_key); + + // PQXDH pending should return None + let uuid2 = Uuid::new_v4(); + store + .insert_new_agreement( + uuid2, + KeyMaterial::PqxdhPending { + x25519_shared: vec![0x01; 32], + kyber_secret: vec![0x02; 2400], + }, + ) + .await + .unwrap(); + assert!(store.get_shared_secret(&uuid2).await.is_none()); + + // PQXDH complete should return the session key + let uuid3 = Uuid::new_v4(); + let session_key = vec![0x99; 32]; + store + .insert_new_agreement(uuid3, KeyMaterial::PqxdhComplete(session_key.clone())) + .await + .unwrap(); + assert_eq!(store.get_shared_secret(&uuid3).await.unwrap(), session_key); +}