diff --git a/crates/daphne/src/lib.rs b/crates/daphne/src/lib.rs index 549c94ef0..3cf636fea 100644 --- a/crates/daphne/src/lib.rs +++ b/crates/daphne/src/lib.rs @@ -91,7 +91,7 @@ use url::Url; use vdaf::mastic::MasticWeight; pub use messages::request::{DapRequest, DapRequestMeta, DapResponse}; -pub use protocol::aggregator::InitializedReport; +pub use protocol::report_init::{InitializedReport, WithPeerPrepShare}; /// DAP version used for a task. #[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)] diff --git a/crates/daphne/src/protocol/aggregator.rs b/crates/daphne/src/protocol/aggregator.rs index e106bd0b8..b3c928aa9 100644 --- a/crates/daphne/src/protocol/aggregator.rs +++ b/crates/daphne/src/protocol/aggregator.rs @@ -1,8 +1,12 @@ // Copyright (c) 2024 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause +use super::{ + no_duplicates, + report_init::{InitializedReport, WithPeerPrepShare}, +}; #[cfg(feature = "experimental")] -use crate::vdaf::mastic::{mastic_prep_finish, mastic_prep_finish_from_shares, mastic_prep_init}; +use crate::vdaf::mastic::{mastic_prep_finish, mastic_prep_finish_from_shares}; use crate::{ constants::DapAggregatorRole, error::DapAbort, @@ -10,248 +14,27 @@ use crate::{ hpke::{info_and_aad, HpkeConfig, HpkeDecrypter}, messages::{ self, encode_u32_bytes, AggregationJobInitReq, AggregationJobResp, Base64Encode, - BatchSelector, Extension, HpkeCiphertext, PartialBatchSelector, PlaintextInputShare, - PrepareInit, Report, ReportId, ReportMetadata, ReportShare, TaskId, Transition, - TransitionFailure, TransitionVar, + BatchSelector, HpkeCiphertext, PartialBatchSelector, PrepareInit, Report, ReportId, + ReportShare, TaskId, Transition, TransitionFailure, TransitionVar, }, metrics::{DaphneMetrics, ReportStatus}, + protocol::{decode_ping_pong_framed, PingPongMessageType}, vdaf::{ - prio2::{prio2_prep_finish, prio2_prep_finish_from_shares, prio2_prep_init}, - prio3::{prio3_prep_finish, prio3_prep_finish_from_shares, prio3_prep_init}, - VdafError, VdafPrepShare, VdafPrepState, + prio2::{prio2_prep_finish, prio2_prep_finish_from_shares}, + prio3::{prio3_prep_finish, prio3_prep_finish_from_shares}, + VdafError, }, AggregationJobReportState, DapAggregateShare, DapAggregateSpan, DapAggregationJobState, DapAggregationParam, DapError, DapTaskConfig, DapVersion, VdafConfig, }; -use prio::codec::{ - encode_u32_items, CodecError, Decode, Encode, ParameterizedDecode, ParameterizedEncode, -}; +use prio::codec::{encode_u32_items, Encode, ParameterizedDecode, ParameterizedEncode}; +use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use std::{ collections::{HashMap, HashSet}, - io::Cursor, iter::zip, ops::Range, }; -use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; - -// Ping-pong message framing as defined in draft-irtf-cfrg-vdaf-08, Section 5.8. We do not -// implement the "continue" message type because we only support 1-round VDAFs. -enum PingPongMessageType { - Initialize = 0, - Finish = 2, -} - -// This is essentially a re-implementation of a method in the `messages` module. However the goal -// here is to make it zero-copy. See https://github.com/cloudflare/daphne/issues/15. 
-fn decode_ping_pong_framed( - bytes: &[u8], - expected_type: PingPongMessageType, -) -> Result<&[u8], CodecError> { - let mut r = Cursor::new(bytes); - - let message_type = u8::decode(&mut r)?; - if message_type != expected_type as u8 { - return Err(CodecError::UnexpectedValue); - } - - let message_len = u32::decode(&mut r)?.try_into().unwrap(); - let message_start = usize::try_from(r.position()).unwrap(); - if bytes.len() - message_start < message_len { - return Err(CodecError::LengthPrefixTooBig(message_len)); - } - if bytes.len() - message_start > message_len { - return Err(CodecError::BytesLeftOver(message_len)); - } - - Ok(&bytes[message_start..]) -} - -/// Report state during aggregation initialization after the VDAF preparation step. -#[expect(clippy::large_enum_variant)] -#[derive(Clone)] -#[cfg_attr(any(test, feature = "test-utils"), derive(Debug, deepsize::DeepSizeOf))] -pub enum InitializedReport { - Ready { - metadata: ReportMetadata, - public_share: Vec<u8>, - prep_share: VdafPrepShare, - prep_state: VdafPrepState, - // Set by the Helper. - peer_prep_share: Option<Vec<u8>>, - }, - Rejected { - metadata: ReportMetadata, - failure: TransitionFailure, - }, -} - -impl InitializedReport { - #[expect(clippy::too_many_arguments)] - pub(crate) fn new( - decrypter: &impl HpkeDecrypter, - valid_report_range: Range<messages::Time>, - role: DapAggregatorRole, - task_id: &TaskId, - task_config: &DapTaskConfig, - report_share: ReportShare, - prep_init_payload: Option<Vec<u8>>, - // We need to use this variable for Mastic, which is currently fenced by the - // "experimental" feature. - #[cfg_attr(not(feature = "experimental"), expect(unused_variables))] - agg_param: &DapAggregationParam, - ) -> Result<Self, DapError> { - macro_rules! reject { - ($failure:ident) => { - return Ok(Self::Rejected { - metadata: report_share.report_metadata, - failure: TransitionFailure::$failure, - }) - }; - } - match report_share.report_metadata.time { - t if t >= task_config.not_after => reject!(TaskExpired), - t if t < valid_report_range.start => reject!(ReportDropped), - t if valid_report_range.end < t => reject!(ReportTooEarly), - _ => {} - } - - // decrypt input share - let PlaintextInputShare { - extensions, - payload: input_share, - } = { - let info = info_and_aad::InputShare { - version: task_config.version, - receiver: role, - task_id, - report_metadata: &report_share.report_metadata, - public_share: &report_share.public_share, - }; - - let encoded_input_share = - match decrypter.hpke_decrypt(info, &report_share.encrypted_input_share) { - Ok(encoded_input_share) => encoded_input_share, - Err(DapError::Transition(failure)) => { - return Ok(Self::Rejected { - metadata: report_share.report_metadata, - failure, - }) - } - Err(e) => return Err(e), - }; - - let Ok(plaintext) = PlaintextInputShare::get_decoded_with_param( - &task_config.version, - &encoded_input_share, - ) else { - reject!(InvalidMessage) - }; - plaintext - }; - - // Handle report extensions. - { - if no_duplicates(extensions.iter().map(|e| e.type_code())).is_err() { - reject!(InvalidMessage) - } - let mut taskprov_indicated = false; - for extension in extensions { - match extension { - Extension::Taskprov { .. } if task_config.method_is_taskprov() => { - taskprov_indicated = true; - } - - // Reject reports with unrecognized extensions. - _ => reject!(InvalidMessage), - } - } - - if task_config.method_is_taskprov() && !taskprov_indicated { - // taskprov: If the task configuration method is taskprov, then we expect each - // report to indicate support.
- reject!(InvalidMessage); - } - } - - // Decode the ping-pong "initialize" message framing. - // (draft-irtf-cfrg-vdaf-08, Section 5.8). - let peer_prep_share = match prep_init_payload - .as_ref() - .map(|payload| decode_ping_pong_framed(payload, PingPongMessageType::Initialize)) - .transpose() - { - Ok(peer_prep_share) => peer_prep_share, - Err(e) => { - tracing::warn!(error = ?e, "rejecting report"); - reject!(VdafPrepError); - } - }; - - let agg_id = match role { - DapAggregatorRole::Leader => 0, - DapAggregatorRole::Helper => 1, - }; - let res = match &task_config.vdaf { - VdafConfig::Prio3(ref prio3_config) => prio3_prep_init( - prio3_config, - &task_config.vdaf_verify_key, - agg_id, - &report_share.report_metadata.id.0, - &report_share.public_share, - &input_share, - ), - VdafConfig::Prio2 { dimension } => prio2_prep_init( - *dimension, - &task_config.vdaf_verify_key, - agg_id, - &report_share.report_metadata.id.0, - &report_share.public_share, - &input_share, - ), - #[cfg(feature = "experimental")] - VdafConfig::Mastic { - input_size, - weight_config, - } => mastic_prep_init( - *input_size, - *weight_config, - &task_config.vdaf_verify_key, - agg_param, - &report_share.public_share, - input_share.as_ref(), - ), - VdafConfig::Pine(pine) => pine.prep_init( - &task_config.vdaf_verify_key, - agg_id, - &report_share.report_metadata.id.0, - &report_share.public_share, - &input_share, - ), - }; - - Ok(match res { - Ok((prep_state, prep_share)) => Self::Ready { - metadata: report_share.report_metadata, - public_share: report_share.public_share, - peer_prep_share: peer_prep_share.map(|p| p.to_vec()), - prep_share, - prep_state, - }, - Err(e) => { - tracing::warn!(error = ?e, "rejecting report"); - reject!(VdafPrepError); - } - }) - } - - pub(crate) fn metadata(&self) -> &ReportMetadata { - match self { - Self::Ready { metadata, .. } | Self::Rejected { metadata, .. } => metadata, - } - } -} - pub(crate) enum ReportProcessedStatus { /// The report should be marked as aggregated. However it has already been committed to /// storage, so don't do so again. 
@@ -343,10 +126,9 @@ impl DapTaskConfig { let [leader_share, helper_share] = report.encrypted_input_shares; - let initialized_report = InitializedReport::new( + let initialized_report = InitializedReport::from_client( &decrypter, valid_report_time_range.clone(), - DapAggregatorRole::Leader, task_id, self, ReportShare { @@ -354,14 +136,13 @@ impl DapTaskConfig { public_share: report.public_share, encrypted_input_share: leader_share, }, - None, agg_param, )?; match initialized_report { InitializedReport::Ready { metadata, public_share, - peer_prep_share: _, + peer_prep_share: (), prep_share, prep_state, } => { @@ -454,7 +235,7 @@ impl DapTaskConfig { task_id: &TaskId, agg_job_init_req: AggregationJobInitReq, replay_protection: ReplayProtection, - ) -> Result<Vec<InitializedReport>, DapError> + ) -> Result<Vec<InitializedReport<WithPeerPrepShare>>, DapError> where H: HpkeDecrypter + Sync, { @@ -478,14 +259,13 @@ impl DapTaskConfig { .prep_inits .into_par_iter() .map(|prep_init| { - InitializedReport::new( + InitializedReport::from_leader( decrypter, valid_report_time_range.clone(), - DapAggregatorRole::Helper, task_id, self, prep_init.report_share, - Some(prep_init.payload), + prep_init.payload, &agg_param, ) }) @@ -499,7 +279,7 @@ impl DapTaskConfig { &self, report_status: &HashMap<ReportId, ReportProcessedStatus>, part_batch_sel: &PartialBatchSelector, - initialized_reports: &[InitializedReport], + initialized_reports: &[InitializedReport<WithPeerPrepShare>], ) -> Result<(DapAggregateSpan<DapAggregateShare>, AggregationJobResp), DapError> { let num_reports = initialized_reports.len(); let mut agg_span = DapAggregateSpan::default(); @@ -513,7 +293,7 @@ impl DapTaskConfig { InitializedReport::Ready { metadata, public_share: _, - peer_prep_share: Some(leader_prep_share), + peer_prep_share: leader_prep_share, prep_share: helper_prep_share, prep_state: helper_prep_state, } => { @@ -581,11 +361,6 @@ impl DapTaskConfig { } } - InitializedReport::Ready { - peer_prep_share: None, - .. - } => return Err(fatal_error!(err = "expected leader prep share, got none")), - InitializedReport::Rejected { metadata: _, failure, @@ -769,21 +544,3 @@ fn produce_encrypted_agg_share( hpke_config.encrypt(info, &agg_share_data) } - -/// checks if an iterator has no duplicate items, returns the ok if there are no dups or an error -/// with the first offending item. -fn no_duplicates<I>(iterator: I) -> Result<(), I::Item> -where - I: Iterator, - I::Item: Eq + std::hash::Hash + Copy, -{ - let (lower, upper) = iterator.size_hint(); - let mut seen = HashSet::with_capacity(upper.unwrap_or(lower)); - - for item in iterator { - if !seen.insert(item) { - return Err(item); - } - } - Ok(()) -} diff --git a/crates/daphne/src/protocol/mod.rs b/crates/daphne/src/protocol/mod.rs index 7bfff8578..6c504f956 100644 --- a/crates/daphne/src/protocol/mod.rs +++ b/crates/daphne/src/protocol/mod.rs @@ -1,15 +1,69 @@ // Copyright (c) 2024 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause +use prio::codec::{CodecError, Decode as _}; +use std::{collections::HashSet, io::Cursor}; + pub(crate) mod aggregator; mod client; mod collector; +pub(crate) mod report_init; + +/// checks if an iterator has no duplicate items, returns the ok if there are no dups or an error +/// with the first offending item.
+fn no_duplicates<I>(iterator: I) -> Result<(), I::Item> +where + I: Iterator, + I::Item: Eq + std::hash::Hash, +{ + let (lower, upper) = iterator.size_hint(); + let mut seen = HashSet::with_capacity(upper.unwrap_or(lower)); + + for item in iterator { + if let Some(repeat) = seen.replace(item) { + return Err(repeat); + } + } + Ok(()) +} + +// Ping-pong message framing as defined in draft-irtf-cfrg-vdaf-08, Section 5.8. We do not +// implement the "continue" message type because we only support 1-round VDAFs. +enum PingPongMessageType { + Initialize = 0, + Finish = 2, +} + +// This is essentially a re-implementation of a method in the `messages` module. However the goal +// here is to make it zero-copy. See https://github.com/cloudflare/daphne/issues/15. +fn decode_ping_pong_framed( + bytes: &[u8], + expected_type: PingPongMessageType, +) -> Result<&[u8], CodecError> { + let mut r = Cursor::new(bytes); + + let message_type = u8::decode(&mut r)?; + if message_type != expected_type as u8 { + return Err(CodecError::UnexpectedValue); + } + + let message_len = u32::decode(&mut r)?.try_into().unwrap(); + let message_start = usize::try_from(r.position()).unwrap(); + if bytes.len() - message_start < message_len { + return Err(CodecError::LengthPrefixTooBig(message_len)); + } + if bytes.len() - message_start > message_len { + return Err(CodecError::BytesLeftOver(message_len)); + } + + Ok(&bytes[message_start..]) +} #[cfg(test)] mod test { + use super::{report_init::InitializedReport, PingPongMessageType}; use crate::{ assert_metrics_include, - constants::DapAggregatorRole, error::DapAbort, hpke::{HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId}, messages::{ @@ -17,7 +71,6 @@ mod test { PrepareInit, Report, ReportId, ReportShare, Transition, TransitionFailure, TransitionVar, }, - protocol::aggregator::InitializedReport, test_versions, testing::AggregationJobTest, vdaf::{Prio3Config, VdafConfig}, @@ -27,6 +80,7 @@ mod test { use assert_matches::assert_matches; use hpke_rs::HpkePublicKey; use prio::{ + codec::encode_u32_items, field::Field64, vdaf::{ prio3::Prio3, AggregateShare, Aggregator as VdafAggregator, Collector as VdafCollector, @@ -58,10 +112,9 @@ mod test { prep_share: leader_prep_share, prep_state: leader_prep_state, .. - } = InitializedReport::new( + } = InitializedReport::from_client( &t.leader_hpke_receiver_config, t.valid_report_time_range(), - DapAggregatorRole::Leader, &t.task_id, &t.task_config, ReportShare { @@ -69,7 +122,6 @@ mod test { public_share: report.public_share.clone(), encrypted_input_share: leader_share, }, - None, &DapAggregationParam::Empty, ) .unwrap() @@ -81,10 +133,9 @@ mod test { prep_share: helper_prep_share, prep_state: helper_prep_state, ..
- } = InitializedReport::new( + } = InitializedReport::from_leader( &t.helper_hpke_receiver_config, t.valid_report_time_range(), - DapAggregatorRole::Helper, &t.task_id, &t.task_config, ReportShare { @@ -92,7 +143,12 @@ mod test { public_share: report.public_share, encrypted_input_share: helper_share, }, - None, + { + let mut outbound = Vec::new(); + outbound.push(PingPongMessageType::Initialize as u8); + encode_u32_items(&mut outbound, &version, &[leader_prep_share.clone()]).unwrap(); + outbound + }, &DapAggregationParam::Empty, ) .unwrap() @@ -682,10 +738,9 @@ mod test { let [leader_share, _] = report.encrypted_input_shares; let report_metadata = report.report_metadata.clone(); - let initialized_report = InitializedReport::new( + let initialized_report = InitializedReport::from_client( &t.leader_hpke_receiver_config, t.valid_report_time_range(), - DapAggregatorRole::Leader, &t.task_id, &t.task_config, ReportShare { @@ -693,7 +748,6 @@ mod test { public_share: report.public_share, encrypted_input_share: leader_share, }, - None, &DapAggregationParam::Empty, ) .unwrap(); @@ -732,10 +786,9 @@ mod test { let report_metadata = report.report_metadata.clone(); let [leader_share, _] = report.encrypted_input_shares; - let initialized_report = InitializedReport::new( + let initialized_report = InitializedReport::from_client( &t.leader_hpke_receiver_config, t.valid_report_time_range(), - DapAggregatorRole::Leader, &t.task_id, &t.task_config, ReportShare { @@ -743,7 +796,6 @@ mod test { public_share: report.public_share, encrypted_input_share: leader_share, }, - None, &DapAggregationParam::Empty, ) .unwrap(); diff --git a/crates/daphne/src/protocol/report_init.rs b/crates/daphne/src/protocol/report_init.rs new file mode 100644 index 000000000..e6578d3c6 --- /dev/null +++ b/crates/daphne/src/protocol/report_init.rs @@ -0,0 +1,282 @@ +// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +#[cfg(feature = "experimental")] +use crate::vdaf::mastic::mastic_prep_init; +use crate::{ + constants::DapAggregatorRole, + hpke::{info_and_aad, HpkeDecrypter}, + messages::{ + self, Extension, PlaintextInputShare, ReportMetadata, ReportShare, TaskId, + TransitionFailure, + }, + protocol::{decode_ping_pong_framed, no_duplicates, PingPongMessageType}, + vdaf::{ + prio2::prio2_prep_init, prio3::prio3_prep_init, VdafConfig, VdafPrepShare, VdafPrepState, + }, + DapAggregationParam, DapError, DapTaskConfig, +}; +use prio::codec::{CodecError, ParameterizedDecode as _}; +use std::ops::{Deref, Range}; + +/// Report state during aggregation initialization after the VDAF preparation step. +/// +/// The `Peer` parameter can be: +/// - `()` if the report came from a client. +/// - [`WithPeerPrepShare`] if the report came from the leader. 
+#[expect(clippy::large_enum_variant)] +#[derive(Clone)] +#[cfg_attr(any(test, feature = "test-utils"), derive(Debug, deepsize::DeepSizeOf))] +pub enum InitializedReport<Peer> { + Ready { + metadata: ReportMetadata, + public_share: Vec<u8>, + prep_share: VdafPrepShare, + prep_state: VdafPrepState, + peer_prep_share: Peer, + }, + Rejected { + metadata: ReportMetadata, + failure: TransitionFailure, + }, +} + +pub struct WithPeerPrepShare(Vec<u8>); + +impl Deref for WithPeerPrepShare { + type Target = Vec<u8>; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl InitializedReport<()> { + pub fn from_client( + decrypter: &impl HpkeDecrypter, + valid_report_range: Range<messages::Time>, + task_id: &TaskId, + task_config: &DapTaskConfig, + report_share: ReportShare, + agg_param: &DapAggregationParam, + ) -> Result<Self, DapError> { + Self::initialize( + decrypter, + valid_report_range, + task_id, + task_config, + report_share, + (), + agg_param, + ) + } +} + +impl InitializedReport<WithPeerPrepShare> { + pub fn from_leader( + decrypter: &impl HpkeDecrypter, + valid_report_range: Range<messages::Time>, + task_id: &TaskId, + task_config: &DapTaskConfig, + report_share: ReportShare, + prep_init_payload: Vec<u8>, + agg_param: &DapAggregationParam, + ) -> Result<Self, DapError> { + Self::initialize( + decrypter, + valid_report_range, + task_id, + task_config, + report_share, + prep_init_payload, + agg_param, + ) + } +} + +impl<P> InitializedReport<P> {
+ fn initialize<S>( + decrypter: &impl HpkeDecrypter, + valid_report_range: Range<messages::Time>, + task_id: &TaskId, + task_config: &DapTaskConfig, + report_share: ReportShare, + prep_init_payload: S, + // We need to use this variable for Mastic, which is currently fenced by the + // "experimental" feature. + #[cfg_attr(not(feature = "experimental"), expect(unused_variables))] + agg_param: &DapAggregationParam, + ) -> Result<Self, DapError> + where + S: PrepInitPayload<Decoded = P>, + { + macro_rules! reject { + ($failure:ident) => { + return Ok(InitializedReport::Rejected { + metadata: report_share.report_metadata, + failure: TransitionFailure::$failure, + }) + }; + } + match report_share.report_metadata.time { + t if t >= task_config.not_after => reject!(TaskExpired), + t if t < valid_report_range.start => reject!(ReportDropped), + t if valid_report_range.end < t => reject!(ReportTooEarly), + _ => {} + } + + // decrypt input share + let PlaintextInputShare { + extensions, + payload: input_share, + } = { + let info = info_and_aad::InputShare { + version: task_config.version, + receiver: S::ROLE, + task_id, + report_metadata: &report_share.report_metadata, + public_share: &report_share.public_share, + }; + + let encoded_input_share = + match decrypter.hpke_decrypt(info, &report_share.encrypted_input_share) { + Ok(encoded_input_share) => encoded_input_share, + Err(DapError::Transition(failure)) => { + return Ok(InitializedReport::Rejected { + metadata: report_share.report_metadata, + failure, + }) + } + Err(e) => return Err(e), + }; + + let Ok(plaintext) = PlaintextInputShare::get_decoded_with_param( + &task_config.version, + &encoded_input_share, + ) else { + reject!(InvalidMessage) + }; + plaintext + }; + + // Handle report extensions. + { + if no_duplicates(extensions.iter().map(|e| e.type_code())).is_err() { + reject!(InvalidMessage) + } + let mut taskprov_indicated = false; + for extension in extensions { + match extension { + Extension::Taskprov { .. } if task_config.method_is_taskprov() => { + taskprov_indicated = true; + } + + // Reject reports with unrecognized extensions. + _ => reject!(InvalidMessage), + } + } + + if task_config.method_is_taskprov() && !taskprov_indicated { + // taskprov: If the task configuration method is taskprov, then we expect each + // report to indicate support. + reject!(InvalidMessage); + } + } + + // Decode the ping-pong "initialize" message framing. + // (draft-irtf-cfrg-vdaf-08, Section 5.8).
+ let peer_prep_share = match prep_init_payload.decode_ping_pong_framed() { + Ok(peer_prep_share) => peer_prep_share, + Err(e) => { + tracing::warn!(error = ?e, "rejecting report"); + reject!(VdafPrepError); + } + }; + + let agg_id = match S::ROLE { + DapAggregatorRole::Leader => 0, + DapAggregatorRole::Helper => 1, + }; + let res = match &task_config.vdaf { + VdafConfig::Prio3(ref prio3_config) => prio3_prep_init( + prio3_config, + &task_config.vdaf_verify_key, + agg_id, + &report_share.report_metadata.id.0, + &report_share.public_share, + &input_share, + ), + VdafConfig::Prio2 { dimension } => prio2_prep_init( + *dimension, + &task_config.vdaf_verify_key, + agg_id, + &report_share.report_metadata.id.0, + &report_share.public_share, + &input_share, + ), + #[cfg(feature = "experimental")] + VdafConfig::Mastic { + input_size, + weight_config, + } => mastic_prep_init( + *input_size, + *weight_config, + &task_config.vdaf_verify_key, + agg_param, + &report_share.public_share, + input_share.as_ref(), + ), + VdafConfig::Pine(pine) => pine.prep_init( + &task_config.vdaf_verify_key, + agg_id, + &report_share.report_metadata.id.0, + &report_share.public_share, + &input_share, + ), + }; + + match res { + Ok((prep_state, prep_share)) => Ok(InitializedReport::Ready { + metadata: report_share.report_metadata, + public_share: report_share.public_share, + peer_prep_share, + prep_share, + prep_state, + }), + Err(e) => { + tracing::warn!(error = ?e, "rejecting report"); + reject!(VdafPrepError); + } + } + } + + pub(crate) fn metadata(&self) -> &ReportMetadata { + match self { + Self::Ready { metadata, .. } | Self::Rejected { metadata, .. } => metadata, + } + } +} + +/// This trait's purpose is to permit sharing the initialization logic of reports from clients and +/// from leaders, by generically implementing the only part that's different. +trait PrepInitPayload { + type Decoded; + const ROLE: DapAggregatorRole; + fn decode_ping_pong_framed(&self) -> Result<Self::Decoded, CodecError>; +} + +impl PrepInitPayload for () { + type Decoded = (); + const ROLE: DapAggregatorRole = DapAggregatorRole::Leader; + fn decode_ping_pong_framed(&self) -> Result<Self::Decoded, CodecError> { + Ok(()) + } +} + +impl PrepInitPayload for Vec<u8> { + type Decoded = WithPeerPrepShare; + const ROLE: DapAggregatorRole = DapAggregatorRole::Helper; + fn decode_ping_pong_framed(&self) -> Result<Self::Decoded, CodecError> { + decode_ping_pong_framed(self, PingPongMessageType::Initialize) + .map(|b| WithPeerPrepShare(b.to_vec())) + } +} diff --git a/crates/daphne/src/roles/helper.rs b/crates/daphne/src/roles/helper.rs index bb548af60..f72aa42f1 100644 --- a/crates/daphne/src/roles/helper.rs +++ b/crates/daphne/src/roles/helper.rs @@ -15,7 +15,10 @@ use crate::{ AggregationJobResp, PartialBatchSelector, TaskId, TransitionFailure, TransitionVar, }, metrics::{DaphneMetrics, DaphneRequestType, ReportStatus}, - protocol::aggregator::{InitializedReport, ReplayProtection, ReportProcessedStatus}, + protocol::{ + aggregator::{ReplayProtection, ReportProcessedStatus}, + report_init::{InitializedReport, WithPeerPrepShare}, + }, roles::aggregator::MergeAggShareError, DapAggregationParam, DapError, DapRequest, DapResponse, DapTaskConfig, }; @@ -208,7 +211,7 @@ async fn finish_agg_job_and_aggregate( task_id: &TaskId, task_config: &DapTaskConfig, part_batch_sel: &PartialBatchSelector, - initialized_reports: &[InitializedReport], + initialized_reports: &[InitializedReport<WithPeerPrepShare>], metrics: &dyn DaphneMetrics, ) -> Result<AggregationJobResp, DapError> { // This loop is intended to run at most once on the "happy path". The intent is as follows:
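
For review context, not part of the patch: a minimal caller-side sketch of how the two new constructors are intended to be used, written against the signatures introduced above. The wrapper functions and parameter names are hypothetical, and `daphne::messages::Time` is assumed to be the timestamp type used for `valid_report_range`; only the `InitializedReport`/`WithPeerPrepShare` API itself comes from this change.

use daphne::{
    hpke::HpkeDecrypter,
    messages::{ReportShare, TaskId, Time},
    DapAggregationParam, DapError, DapTaskConfig, InitializedReport, WithPeerPrepShare,
};
use std::ops::Range;

// Leader side: the report share comes straight from a client, so there is no peer
// prep share and the result is `InitializedReport<()>`.
fn leader_initialize(
    decrypter: &impl HpkeDecrypter,
    valid_report_range: Range<Time>,
    task_id: &TaskId,
    task_config: &DapTaskConfig,
    report_share: ReportShare,
    agg_param: &DapAggregationParam,
) -> Result<InitializedReport<()>, DapError> {
    InitializedReport::from_client(
        decrypter,
        valid_report_range,
        task_id,
        task_config,
        report_share,
        agg_param,
    )
}

// Helper side: the leader's ping-pong "initialize" payload is passed in and, on success,
// surfaces as `WithPeerPrepShare`, which derefs to the leader's prep-share bytes with the
// ping-pong framing already stripped.
fn helper_initialize(
    decrypter: &impl HpkeDecrypter,
    valid_report_range: Range<Time>,
    task_id: &TaskId,
    task_config: &DapTaskConfig,
    report_share: ReportShare,
    prep_init_payload: Vec<u8>,
    agg_param: &DapAggregationParam,
) -> Result<InitializedReport<WithPeerPrepShare>, DapError> {
    InitializedReport::from_leader(
        decrypter,
        valid_report_range,
        task_id,
        task_config,
        report_share,
        prep_init_payload,
        agg_param,
    )
}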
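
Also for reviewers: the ping-pong "initialize" framing that the `PrepInitPayload` impl for `Vec<u8>` decodes is a one-byte message type, a u32 length prefix, and then exactly that many payload bytes (draft-irtf-cfrg-vdaf-08, Section 5.8). Below is a stand-alone sketch of the encoder side, assuming the length prefix is big-endian as in the TLS-style encoding used by `prio::codec`; the function name is illustrative and mirrors what the new test code builds with `encode_u32_items`.

// Sketch only: frame `prep_share_bytes` the way `decode_ping_pong_framed` expects an
// "initialize" message: [type: u8 = 0][length: u32, big-endian][payload bytes].
fn frame_initialize(prep_share_bytes: &[u8]) -> Vec<u8> {
    let mut framed = Vec::with_capacity(1 + 4 + prep_share_bytes.len());
    framed.push(0u8); // PingPongMessageType::Initialize
    framed.extend_from_slice(&u32::try_from(prep_share_bytes.len()).unwrap().to_be_bytes());
    framed.extend_from_slice(prep_share_bytes);
    framed
}

fn main() {
    let framed = frame_initialize(b"example prep share");
    assert_eq!(framed[0], 0); // message type: initialize
    assert_eq!(&framed[1..5], &18u32.to_be_bytes()); // length prefix covers the 18-byte payload
    assert_eq!(&framed[5..], b"example prep share");
}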