diff --git a/crates/daphne/src/messages/taskprov.rs b/crates/daphne/src/messages/taskprov.rs index 166c0fc71..ebeb18477 100644 --- a/crates/daphne/src/messages/taskprov.rs +++ b/crates/daphne/src/messages/taskprov.rs @@ -556,7 +556,7 @@ impl TaskprovAdvertisement { ) })?; - if compute_task_id(taskprov_data.as_ref()) != *task_id { + if compute_task_id(version, taskprov_data.as_ref()) != *task_id { // Return unrecognizedTask following section 5.1 of the taskprov draft. return Err(DapAbort::UnrecognizedTask { task_id: *task_id }); } @@ -578,7 +578,7 @@ impl TaskprovAdvertisement { #[cfg(any(test, feature = "test-utils"))] pub fn compute_task_id(&self, version: DapVersion) -> TaskId { - compute_task_id(&self.get_encoded_with_param(&version).unwrap()) + compute_task_id(version, &self.get_encoded_with_param(&version).unwrap()) } } diff --git a/crates/daphne/src/taskprov.rs b/crates/daphne/src/taskprov.rs index 247df18ab..1e41c0ae7 100644 --- a/crates/daphne/src/taskprov.rs +++ b/crates/daphne/src/taskprov.rs @@ -30,19 +30,27 @@ use ring::{ use serde::{Deserialize, Serialize}; use url::Url; -/// SHA-256 of "dap-taskprov" -pub(crate) const TASKPROV_SALT: [u8; 32] = [ - 0x28, 0xb9, 0xbb, 0x4f, 0x62, 0x4f, 0x67, 0x9a, 0xc1, 0x98, 0xd9, 0x68, 0xf4, 0xb0, 0x9e, 0xec, - 0x74, 0x01, 0x7a, 0x52, 0xcb, 0x4c, 0xf6, 0x39, 0xfb, 0x83, 0xe0, 0x47, 0x72, 0x3a, 0x0f, 0xfe, +/// SHA-256 of `b"dap-taskprov"`. +const TASKPROV_SALT: [u8; 32] = [ + 40, 185, 187, 79, 98, 79, 103, 154, 193, 152, 217, 104, 244, 176, 158, 236, 116, 1, 122, 82, + 203, 76, 246, 57, 251, 131, 224, 71, 114, 58, 15, 254, +]; + +/// SHA-256 of `b"dap-takprov task id"`. +const TASKPROV_TASK_ID_SALT: [u8; 32] = [ + 70, 13, 237, 116, 40, 100, 135, 190, 152, 104, 104, 209, 157, 184, 219, 27, 5, 132, 88, 56, + 228, 214, 41, 30, 241, 91, 110, 32, 82, 11, 220, 130, ]; /// Compute the task id of a serialized task config. -pub(crate) fn compute_task_id(serialized: &[u8]) -> TaskId { - let d = digest::digest(&digest::SHA256, serialized); - let dref = d.as_ref(); - let mut b: [u8; 32] = [0; 32]; - b[..32].copy_from_slice(&dref[..32]); - TaskId(b) +pub(crate) fn compute_task_id(version: DapVersion, taskprov_advertisemnt_bytes: &[u8]) -> TaskId { + let mut hash = ring::digest::Context::new(&digest::SHA256); + if version == DapVersion::Latest { + hash.update(&TASKPROV_TASK_ID_SALT); + } + hash.update(taskprov_advertisemnt_bytes); + let digest = hash.finish(); + TaskId(digest.as_ref().try_into().unwrap()) } // The documentation for ring::hkdf says computing the Salt is expensive, and we use the same PRK all the @@ -537,6 +545,7 @@ mod test { }; let task_id = compute_task_id( + version, &taskprov_advertisemnt .get_encoded_with_param(&version) .unwrap(), @@ -589,6 +598,75 @@ mod test { test_versions! { check_vdaf_key_computation } + #[test] + fn check_task_id_draft09() { + let taskprov_advertisemnt_bytes = messages::taskprov::TaskprovAdvertisement { + task_info: "cool task".as_bytes().to_vec(), + leader_url: messages::taskprov::UrlBytes { + bytes: b"https://leader.com/".to_vec(), + }, + helper_url: messages::taskprov::UrlBytes { + bytes: b"http://helper.org:8788/".to_vec(), + }, + time_precision: 3600, + min_batch_size: 1, + query_config: messages::taskprov::QueryConfig::LeaderSelected { + draft09_max_batch_size: Some(NonZeroU32::new(2).unwrap()), + }, + lifetime: messages::taskprov::TaskLifetime::Draft09 { expiration: 23 }, + vdaf_config: messages::taskprov::VdafConfig::Prio2 { dimension: 10 }, + extensions: Vec::new(), + draft09_max_batch_query_count: Some(23), + draft09_dp_config: Some(messages::taskprov::DpConfig::None), + } + .get_encoded_with_param(&DapVersion::Draft09) + .unwrap(); + + let expected_task_id = TaskId([ + 142, 26, 248, 229, 126, 249, 222, 59, 10, 221, 34, 151, 27, 60, 28, 0, 134, 194, 142, + 84, 167, 128, 139, 140, 98, 35, 119, 117, 109, 108, 125, 211, + ]); + let task_id = compute_task_id(DapVersion::Latest, &taskprov_advertisemnt_bytes); + println!("{:?}", task_id.0); + assert_eq!(task_id, expected_task_id); + } + + #[test] + fn check_task_id() { + let taskprov_advertisemnt_bytes = messages::taskprov::TaskprovAdvertisement { + task_info: "cool task".as_bytes().to_vec(), + leader_url: messages::taskprov::UrlBytes { + bytes: b"https://leader.com/".to_vec(), + }, + helper_url: messages::taskprov::UrlBytes { + bytes: b"http://helper.org:8788/".to_vec(), + }, + time_precision: 3600, + min_batch_size: 1, + query_config: messages::taskprov::QueryConfig::LeaderSelected { + draft09_max_batch_size: None, + }, + lifetime: messages::taskprov::TaskLifetime::Latest { + start: 23, + duration: 23, + }, + vdaf_config: messages::taskprov::VdafConfig::Prio2 { dimension: 10 }, + extensions: Vec::new(), + draft09_max_batch_query_count: None, + draft09_dp_config: None, + } + .get_encoded_with_param(&DapVersion::Latest) + .unwrap(); + + let expected_task_id = TaskId([ + 29, 66, 37, 142, 99, 73, 46, 14, 193, 147, 230, 204, 154, 75, 129, 177, 55, 2, 228, 62, + 227, 204, 248, 200, 120, 251, 5, 161, 203, 149, 72, 55, + ]); + let task_id = compute_task_id(DapVersion::Latest, &taskprov_advertisemnt_bytes); + println!("{:?}", task_id.0); + assert_eq!(task_id, expected_task_id); + } + fn resolve_advertised_task_config_expect_abort_unrecognized_vdaf(version: DapVersion) { // Create a request for a taskprov task with an unrecognized VDAF. let (req, task_id) = { @@ -626,6 +704,7 @@ mod test { }; let task_id = { compute_task_id( + version, &taskprov_advertisement .get_encoded_with_param(&version) .unwrap(), @@ -698,6 +777,7 @@ mod test { }; let task_id = { compute_task_id( + version, &taskprov_advertisement .get_encoded_with_param(&version) .unwrap(),