diff --git a/daphne/benches/aggregation.rs b/daphne/benches/aggregation.rs index 70f9d01b5..b59abeede 100644 --- a/daphne/benches/aggregation.rs +++ b/daphne/benches/aggregation.rs @@ -3,8 +3,8 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use daphne::{ - hpke::HpkeKemId, testing::AggregationJobTest, DapLeaderTransition, DapMeasurement, DapVersion, - Prio3Config, VdafConfig, + hpke::HpkeKemId, testing::AggregationJobTest, DapLeaderAggregationJobTransition, + DapMeasurement, DapVersion, Prio3Config, VdafConfig, }; fn handle_agg_job_init_req(c: &mut Criterion) { @@ -38,7 +38,7 @@ fn handle_agg_job_init_req(c: &mut Criterion) { let agg_job_init_req = rt.block_on(async { let reports = agg_job_test.produce_reports(vec![measurement; batch_size]); - let DapLeaderTransition::Continue(_leader_state, agg_job_init_req) = + let DapLeaderAggregationJobTransition::Continued(_leader_state, agg_job_init_req) = agg_job_test.produce_agg_job_init_req(reports).await else { panic!("unexpected transition"); @@ -46,7 +46,7 @@ fn handle_agg_job_init_req(c: &mut Criterion) { agg_job_init_req }); - c.bench_function(&format!("handle_agg_job_init_req {vdaf:?}"), |b| { + c.bench_function(&format!("handle_agg_job_init_req {vdaf}"), |b| { b.to_async(&rt).iter(|| async { black_box(agg_job_test.handle_agg_job_init_req(&agg_job_init_req)).await }) diff --git a/daphne/src/lib.rs b/daphne/src/lib.rs index 18e0d250d..9dacc93c2 100644 --- a/daphne/src/lib.rs +++ b/daphne/src/lib.rs @@ -535,40 +535,53 @@ pub enum DapAggregateResult { U128Vec(Vec), } -/// The Leader's state after sending an AggregateInitReq. -#[derive(Debug)] -pub struct DapLeaderState { - pub(crate) seq: Vec<(VdafPrepState, VdafPrepMessage, Time, ReportId)>, +#[derive(Clone)] +#[cfg_attr(any(test, feature = "test-utils"), derive(Debug, deepsize::DeepSizeOf))] +pub(crate) struct AggregationJobReportState { + // draft02 compatibility: The Leader does not transmit its prep share. 
+ draft02_prep_share: Option, + prep_state: VdafPrepState, + time: Time, + report_id: ReportId, +} + +/// Aggregator state during an aggregation job. +#[derive(Clone)] +#[cfg_attr(any(test, feature = "test-utils"), derive(Debug, deepsize::DeepSizeOf))] +pub struct DapAggregationJobState { + pub(crate) seq: Vec, part_batch_sel: PartialBatchSelector, } -/// The Leader's state after sending an AggregateContReq. +/// Leader state during an aggregation job in which it has computed the output shares but is +/// waiting for the Helper's response before it commits them. #[derive(Debug)] -pub struct DapLeaderUncommitted { - pub(crate) seq: Vec<(DapOutputShare, ReportId)>, +pub struct DapAggregationJobUncommitted { + pub(crate) seq: Vec, part_batch_sel: PartialBatchSelector, } -/// The Helper's state during the aggregation flow. -#[derive(Clone, Debug, PartialEq)] -#[cfg_attr(any(test, feature = "test-utils"), derive(deepsize::DeepSizeOf))] -pub struct DapHelperState { - pub(crate) part_batch_sel: PartialBatchSelector, - pub(crate) seq: Vec<(VdafPrepState, Time, ReportId)>, -} - -impl Encode for DapHelperState { +impl Encode for DapAggregationJobState { fn encode(&self, bytes: &mut Vec) { self.part_batch_sel.encode(bytes); - for (state, time, report_id) in self.seq.iter() { - state.encode(bytes); - time.encode(bytes); - report_id.encode(bytes); + for report_state in self.seq.iter() { + if report_state.draft02_prep_share.is_some() { + // draft02 compatibility: The prep share is kept in this data structure for + // backwards compatibility. It's only used by the Leader, so we don't ever expect + // to encode it. + panic!("Tried to encode DapAggregationJobState with leader prep share"); + } + report_state.prep_state.encode(bytes); + report_state.time.encode(bytes); + report_state.report_id.encode(bytes); } } } -impl DapHelperState { +// TODO(cjpatton) Consider replacing this with an implementation of +// `ParameterizedDecode`. 
This requires changing the wire format to make the sequence +// of report states length-prefixed. +impl DapAggregationJobState { /// Decode the Helper state from a byte string. pub fn get_decoded(vdaf_config: &VdafConfig, data: &[u8]) -> Result { let mut r = std::io::Cursor::new(data); @@ -576,15 +589,20 @@ impl DapHelperState { .map_err(|e| DapAbort::from_codec_error(e, None))?; let mut seq = vec![]; while (r.position() as usize) < data.len() { - let state = VdafPrepState::decode_with_param(&(vdaf_config, false), &mut r) + let prep_state = VdafPrepState::decode_with_param(&(vdaf_config, false), &mut r) .map_err(|e| DapAbort::from_codec_error(e, None))?; let time = Time::decode(&mut r).map_err(|e| DapAbort::from_codec_error(e, None))?; let report_id = ReportId::decode(&mut r).map_err(|e| DapAbort::from_codec_error(e, None))?; - seq.push((state, time, report_id)) + seq.push(AggregationJobReportState { + draft02_prep_share: None, + prep_state, + time, + report_id, + }); } - Ok(DapHelperState { + Ok(Self { part_batch_sel, seq, }) @@ -694,32 +712,27 @@ impl DapAggregateShare { } } -/// Leader state transition during the aggregation flow. -#[derive(Debug)] -pub enum DapLeaderTransition { - /// The Leader has produced the next outbound message and its state has been updated. - Continue(DapLeaderState, M), +/// Leader state transitions during the aggregation flow. +#[cfg_attr(any(test, feature = "test-utils"), derive(Debug))] +pub enum DapLeaderAggregationJobTransition { + /// Waiting for a response from the Helper. + Continued(DapAggregationJobState, M), - /// The leader has computed output shares, but is waiting on an AggregateResp from the hepler - /// before committing them. - Uncommitted(DapLeaderUncommitted, M), + /// Output shares computed, but waiting for a response from the Helper before committing. + Uncommitted(DapAggregationJobUncommitted, M), - /// The Leader has completed the aggregation flow without computing an aggregate share. 
- Skip, + /// Committed to the output shares. + Finished(Option), } -/// Helper state transition during the aggregation flow. -#[derive(Debug)] -pub enum DapHelperTransition { - /// The Helper has produced the next outbound message and its state has been updated. - Continue(DapHelperState, M), +/// Helper state transitions during the aggregation flow. +#[cfg_attr(any(test, feature = "test-utils"), derive(Debug))] +pub enum DapHelperAggregationJobTransition { + /// Waiting for a response from the Leader. + Continued(DapAggregationJobState, M), - /// The Helper has produced the last outbound message and has computed a sequence of output - /// shares. - // - // TODO Instead of merging all output shares into a single aggregate share, return a collection - // of aggregat shares, each corresponding to a different batch interval. - Finish(Vec, M), + /// Committed to the output shares. + Finished(Option, M), } /// Specification of a concrete VDAF. diff --git a/daphne/src/messages/mod.rs b/daphne/src/messages/mod.rs index 80ae23f82..7311d4a0b 100644 --- a/daphne/src/messages/mod.rs +++ b/daphne/src/messages/mod.rs @@ -412,6 +412,45 @@ impl TryFrom for BatchSelector { } } +/// The PrepareInit message consisting of the report share and the Leader's initial prep share. 
+#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PrepareInit { + pub report_share: ReportShare, + pub draft05_payload: Option>, +} + +impl ParameterizedEncode for PrepareInit { + fn encode_with_param(&self, version: &DapVersion, bytes: &mut Vec) { + self.report_share.encode_with_param(version, bytes); + match (version, &self.draft05_payload) { + (DapVersion::Draft02, None) => (), + (DapVersion::Draft07, Some(payload)) => { + encode_u32_bytes(bytes, payload); + } + _ => unreachable!("unhandled version {version:?}"), + } + } +} + +impl ParameterizedDecode for PrepareInit { + fn decode_with_param( + version: &DapVersion, + bytes: &mut Cursor<&[u8]>, + ) -> Result { + let report_share = ReportShare::decode_with_param(version, bytes)?; + let draft05_payload = match version { + DapVersion::Draft02 => None, + DapVersion::Draft07 => Some(decode_u32_bytes(bytes)?), + _ => unreachable!("unhandled version {version:?}"), + }; + + Ok(Self { + report_share, + draft05_payload, + }) + } +} + /// Aggregate initialization request. 
#[derive(Clone, Debug, PartialEq, Eq)] pub struct AggregationJobInitReq { @@ -419,7 +458,7 @@ pub struct AggregationJobInitReq { pub draft02_agg_job_id: Option, // Set in draft02 pub agg_param: Vec, pub part_batch_sel: PartialBatchSelector, - pub report_shares: Vec, + pub prep_inits: Vec, } impl ParameterizedEncode for AggregationJobInitReq { @@ -440,7 +479,7 @@ impl ParameterizedEncode for AggregationJobInitReq { DapVersion::Unknown => unreachable!("unhandled version {version:?}"), }; self.part_batch_sel.encode(bytes); - encode_u32_items(bytes, version, &self.report_shares); + encode_u32_items(bytes, version, &self.prep_inits); } } @@ -464,7 +503,7 @@ impl ParameterizedDecode for AggregationJobInitReq { draft02_agg_job_id, agg_param, part_batch_sel: PartialBatchSelector::decode(bytes)?, - report_shares: decode_u32_items(version, bytes)?, + prep_inits: decode_u32_items(version, bytes)?, }) } } @@ -1293,8 +1332,56 @@ mod test { part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { batch_id: BatchId([0; 32]), }, - report_shares: vec![ - ReportShare { + prep_inits: vec![ + PrepareInit { + report_share: ReportShare { + report_metadata: ReportMetadata { + id: ReportId([99; 16]), + time: 1637361337, + extensions: Vec::default(), + }, + public_share: b"public share".to_vec(), + encrypted_input_share: HpkeCiphertext { + config_id: 23, + enc: b"encapsulated key".to_vec(), + payload: b"ciphertext".to_vec(), + }, + }, + draft05_payload: None, + }, + PrepareInit { + report_share: ReportShare { + report_metadata: ReportMetadata { + id: ReportId([17; 16]), + time: 163736423, + extensions: Vec::default(), + }, + public_share: b"public share".to_vec(), + encrypted_input_share: HpkeCiphertext { + config_id: 0, + enc: vec![], + payload: b"ciphertext".to_vec(), + }, + }, + draft05_payload: None, + }, + ], + }, + ); + } + + #[test] + fn roundtrip_agg_job_init_req() { + let want = AggregationJobInitReq { + draft02_task_id: Some(TaskId([23; 32])), + draft02_agg_job_id: 
Some(Draft02AggregationJobId([1; 32])), + agg_param: b"this is an aggregation parameter".to_vec(), + part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { + batch_id: BatchId([0; 32]), + }, + prep_inits: vec![ + PrepareInit { + report_share: ReportShare { report_metadata: ReportMetadata { id: ReportId([99; 16]), time: 1637361337, @@ -1307,7 +1394,10 @@ mod test { payload: b"ciphertext".to_vec(), }, }, - ReportShare { + draft05_payload: None, + }, + PrepareInit { + report_share: ReportShare { report_metadata: ReportMetadata { id: ReportId([17; 16]), time: 163736423, @@ -1320,46 +1410,7 @@ mod test { payload: b"ciphertext".to_vec(), }, }, - ], - }, - ); - } - - #[test] - fn roundtrip_agg_job_init_req() { - let want = AggregationJobInitReq { - draft02_task_id: Some(TaskId([23; 32])), - draft02_agg_job_id: Some(Draft02AggregationJobId([1; 32])), - agg_param: b"this is an aggregation parameter".to_vec(), - part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { - batch_id: BatchId([0; 32]), - }, - report_shares: vec![ - ReportShare { - report_metadata: ReportMetadata { - id: ReportId([99; 16]), - time: 1637361337, - extensions: Vec::default(), - }, - public_share: b"public share".to_vec(), - encrypted_input_share: HpkeCiphertext { - config_id: 23, - enc: b"encapsulated key".to_vec(), - payload: b"ciphertext".to_vec(), - }, - }, - ReportShare { - report_metadata: ReportMetadata { - id: ReportId([17; 16]), - time: 163736423, - extensions: Vec::default(), - }, - public_share: b"public share".to_vec(), - encrypted_input_share: HpkeCiphertext { - config_id: 0, - enc: vec![], - payload: b"ciphertext".to_vec(), - }, + draft05_payload: None, }, ], }; @@ -1378,32 +1429,38 @@ mod test { part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { batch_id: BatchId([0; 32]), }, - report_shares: vec![ - ReportShare { - report_metadata: ReportMetadata { - id: ReportId([99; 16]), - time: 1637361337, - extensions: Vec::default(), - }, - public_share: b"public share".to_vec(), - 
encrypted_input_share: HpkeCiphertext { - config_id: 23, - enc: b"encapsulated key".to_vec(), - payload: b"ciphertext".to_vec(), + prep_inits: vec![ + PrepareInit { + report_share: ReportShare { + report_metadata: ReportMetadata { + id: ReportId([99; 16]), + time: 1637361337, + extensions: Vec::default(), + }, + public_share: b"public share".to_vec(), + encrypted_input_share: HpkeCiphertext { + config_id: 23, + enc: b"encapsulated key".to_vec(), + payload: b"ciphertext".to_vec(), + }, }, + draft05_payload: Some(b"prep share".to_vec()), }, - ReportShare { - report_metadata: ReportMetadata { - id: ReportId([17; 16]), - time: 163736423, - extensions: Vec::default(), - }, - public_share: b"public share".to_vec(), - encrypted_input_share: HpkeCiphertext { - config_id: 0, - enc: vec![], - payload: b"ciphertext".to_vec(), + PrepareInit { + report_share: ReportShare { + report_metadata: ReportMetadata { + id: ReportId([17; 16]), + time: 163736423, + extensions: Vec::default(), + }, + public_share: b"public share".to_vec(), + encrypted_input_share: HpkeCiphertext { + config_id: 0, + enc: vec![], + payload: b"ciphertext".to_vec(), + }, }, + draft05_payload: Some(b"prep share".to_vec()), }, ], }; diff --git a/daphne/src/metrics.rs b/daphne/src/metrics.rs index 60b993593..dff2372d0 100644 --- a/daphne/src/metrics.rs +++ b/daphne/src/metrics.rs @@ -23,7 +23,7 @@ pub struct DaphneMetrics { /// Helper: Number of records in an incoming AggregationJobInitReq. aggregation_job_batch_size_histogram: HistogramVec, - /// Helper: Number of times replays caused the aggregation to be retried. + /// Number of times `try_put_agg_share_span()` was retried. 
aggregation_job_commit_storage_retry_counter: IntCounterVec, } diff --git a/daphne/src/roles/helper.rs b/daphne/src/roles/helper.rs index 23fdc71f8..07cf34281 100644 --- a/daphne/src/roles/helper.rs +++ b/daphne/src/roles/helper.rs @@ -13,15 +13,14 @@ use crate::{ audit_log::AggregationJobAuditAction, constants::DapMediaType, error::DapAbort, - fatal_error, messages::{ constant_time_eq, AggregateShare, AggregateShareReq, AggregationJobContinueReq, AggregationJobInitReq, AggregationJobResp, Draft02AggregationJobId, PartialBatchSelector, ReportId, TaskId, }, metrics::{ContextualizedDaphneMetrics, DaphneRequestType}, - DapAggregateShareSpan, DapError, DapHelperState, DapHelperTransition, DapRequest, DapResource, - DapResponse, DapTaskConfig, DapVersion, MetaAggregationJobId, + DapAggregateShareSpan, DapAggregationJobState, DapError, DapHelperAggregationJobTransition, + DapRequest, DapResource, DapResponse, DapTaskConfig, DapVersion, MetaAggregationJobId, }; const COMMIT_REPORTS_TO_STORAGE_RETRIES: usize = 3; @@ -35,7 +34,7 @@ pub trait DapHelper: DapAggregator { &self, task_id: &TaskId, agg_job_id: &MetaAggregationJobId, - helper_state: &DapHelperState, + helper_state: &DapAggregationJobState, ) -> Result; /// Fetch the Helper's aggregation-flow state. `None` is returned if the Helper has no state @@ -44,7 +43,7 @@ pub trait DapHelper: DapAggregator { &self, task_id: &TaskId, agg_job_id: &MetaAggregationJobId, - ) -> Result, DapError>; + ) -> Result, DapError>; async fn handle_agg_job_init_req<'req>( &self, @@ -56,27 +55,32 @@ pub trait DapHelper: DapAggregator { AggregationJobInitReq::get_decoded_with_param(&req.version, &req.payload) .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; - metrics.agg_job_observe_batch_size(agg_job_init_req.report_shares.len()); + metrics.agg_job_observe_batch_size(agg_job_init_req.prep_inits.len()); // taskprov: Resolve the task config to use for the request. 
We also need to ensure // that all of the reports include the task config in the report extensions. (See // section 6 of draft-wang-ppm-dap-taskprov-02.) if let Some(taskprov_version) = self.get_global_config().taskprov_version { let using_taskprov = agg_job_init_req - .report_shares + .prep_inits .iter() - .filter(|share| share.report_metadata.is_taskprov(taskprov_version, task_id)) + .filter(|prep_init| { + prep_init + .report_share + .report_metadata + .is_taskprov(taskprov_version, task_id) + }) .count(); let first_metadata = match using_taskprov { 0 => None, - c if c == agg_job_init_req.report_shares.len() => { + c if c == agg_job_init_req.prep_inits.len() => { // All the extensions use taskprov and look ok, so compute first_metadata. // Note this will always be Some(). agg_job_init_req - .report_shares + .prep_inits .first() - .map(|report_share| &report_share.report_metadata) + .map(|prep_init| &prep_init.report_share.report_metadata) } _ => { // It's not all taskprov or no taskprov, so it's an error. @@ -126,6 +130,7 @@ pub trait DapHelper: DapAggregator { self, task_id, task_config, + |_id| false, &agg_job_init_req, &metrics, ) @@ -133,7 +138,7 @@ pub trait DapHelper: DapAggregator { .await?; let agg_job_resp = match transition { - DapHelperTransition::Continue(state, agg_job_resp) => { + DapHelperAggregationJobTransition::Continued(state, agg_job_resp) => { if !self .put_helper_state_if_not_exists(task_id, &agg_job_id, &state) .await? @@ -143,10 +148,60 @@ pub trait DapHelper: DapAggregator { "unexpected message for aggregation job (already exists)".into(), )); } + metrics.agg_job_started_inc(); agg_job_resp } - DapHelperTransition::Finish(..) => { - return Err(fatal_error!(err = "unexpected transition (finished)").into()); + DapHelperAggregationJobTransition::Finished( + mut maybe_agg_share_span, + mut agg_job_resp, + ) => { + // XXX This loop is messier than I'd like, but should be cleaned up after + // https://github.com/cloudflare/daphne/issues/408. 
+ let mut replayed_reports = HashSet::new(); + let mut out_shares_count = 0; + for _ in 0..COMMIT_REPORTS_TO_STORAGE_RETRIES - 1 { + if let Some(agg_share_span) = maybe_agg_share_span { + out_shares_count = agg_share_span.report_count() as u64; + if let Some(replayed) = self + .try_put_agg_share_span(task_id, task_config, agg_share_span) + .await? + { + replayed_reports.extend(replayed); + metrics.agg_job_commit_storage_retry_inc(); + + let DapHelperAggregationJobTransition::Finished( + new_maybe_agg_share_span, + nwe_agg_job_resp, + ) = task_config + .vdaf + .handle_agg_job_init_req( + self, + self, + task_id, + task_config, + |id| replayed_reports.contains(id), + &agg_job_init_req, + &metrics, + ) + .map_err(DapError::Abort) + .await? + else { + todo!("XXX") + }; + + maybe_agg_share_span = new_maybe_agg_share_span; + agg_job_resp = nwe_agg_job_resp; + continue; + } + } + + break; + } + + metrics.report_inc_by("aggregated", out_shares_count); + metrics.agg_job_started_inc(); + metrics.agg_job_completed_inc(); + agg_job_resp } }; @@ -154,11 +209,10 @@ pub trait DapHelper: DapAggregator { req.host(), task_id, task_config, - agg_job_init_req.report_shares.len() as u64, + agg_job_init_req.prep_inits.len() as u64, AggregationJobAuditAction::Init, ); - metrics.agg_job_started_inc(); metrics.inbound_req_inc(DaphneRequestType::Aggregate); Ok(DapResponse { version: req.version, @@ -480,42 +534,33 @@ mod tests { use futures::StreamExt; use prio::codec::ParameterizedDecode; - use crate::messages::{ - AggregationJobInitReq, AggregationJobResp, ReportShare, Transition, TransitionVar, - }; + use crate::messages::{AggregationJobInitReq, AggregationJobResp, Transition, TransitionVar}; use crate::roles::DapHelper; use crate::MetaAggregationJobId; use crate::{roles::test::TestData, DapVersion}; #[tokio::test] - async fn replay_reports_when_continuing_aggregation() { + async fn replay_reports_when_continuing_aggregation_draft02() { let mut data = TestData::new(DapVersion::Draft02); 
let task_id = data.insert_task( DapVersion::Draft02, - crate::VdafConfig::Prio2 { dimension: 10 }, + crate::VdafConfig::Prio3(crate::Prio3Config::Count), ); let helper = data.new_helper(); let test = data.with_leader(Arc::clone(&helper)); - let report_shares = futures::stream::iter(0..3) - .then(|_| async { - let mut report = test.gen_test_report(&task_id).await; - ReportShare { - report_metadata: report.report_metadata, - public_share: report.public_share, - encrypted_input_share: report.encrypted_input_shares.remove(1), - } - }) + let reports = futures::stream::iter(0..3) + .then(|_| async { test.gen_test_report(&task_id).await }) .collect::>() .await; - let report_ids = report_shares + let report_ids = reports .iter() .map(|r| r.report_metadata.id.clone()) .collect::>(); let req = test - .gen_test_agg_job_init_req(&task_id, DapVersion::Draft02, report_shares) + .gen_test_agg_job_init_req(&task_id, DapVersion::Draft02, reports) .await; let meta_agg_job_id = MetaAggregationJobId::Draft02(Cow::Owned( diff --git a/daphne/src/roles/leader.rs b/daphne/src/roles/leader.rs index 10af9633e..26e0f138e 100644 --- a/daphne/src/roles/leader.rs +++ b/daphne/src/roles/leader.rs @@ -18,8 +18,8 @@ use crate::{ CollectionJobId, CollectionReq, Interval, PartialBatchSelector, Query, Report, TaskId, }, metrics::DaphneRequestType, - DapCollectJob, DapError, DapLeaderProcessTelemetry, DapLeaderTransition, DapRequest, - DapResource, DapResponse, DapTaskConfig, DapVersion, MetaAggregationJobId, + DapCollectJob, DapError, DapLeaderAggregationJobTransition, DapLeaderProcessTelemetry, + DapRequest, DapResource, DapResponse, DapTaskConfig, DapVersion, MetaAggregationJobId, }; struct LeaderHttpRequestOptions<'p> { @@ -347,9 +347,12 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator { ) .await?; let (state, agg_job_init_req) = match transition { - DapLeaderTransition::Continue(state, agg_job_init_req) => (state, agg_job_init_req), - DapLeaderTransition::Skip => return Ok(0), - 
DapLeaderTransition::Uncommitted(..) => { + DapLeaderAggregationJobTransition::Continued(state, agg_job_init_req) => { + (state, agg_job_init_req) + } + DapLeaderAggregationJobTransition::Finished(None) => return Ok(0), + DapLeaderAggregationJobTransition::Finished(Some(..)) + | DapLeaderAggregationJobTransition::Uncommitted(..) => { return Err(fatal_error!(err = "unexpected state transition (uncommitted)").into()) } }; @@ -386,61 +389,61 @@ pub trait DapLeader: DapAuthorizedSender + DapAggregator { let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload) .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; - // Prepare AggreagteContinueReq. + // Handle AggregationJobResp. let transition = task_config.vdaf.handle_agg_job_resp( task_id, + task_config, &agg_job_id, state, agg_job_resp, - task_config.version, &metrics, )?; - let (uncommited, agg_job_cont_req) = match transition { - DapLeaderTransition::Uncommitted(uncommited, agg_job_cont_req) => { - (uncommited, agg_job_cont_req) + let agg_share_span = match transition { + DapLeaderAggregationJobTransition::Uncommitted(uncommited, agg_job_cont_req) => { + // Send AggregationJobContinueReq and receive AggregationJobResp. + let resp = leader_send_http_request( + self, + task_id, + task_config, + LeaderHttpRequestOptions { + path: &url_path, + req_media_type: DapMediaType::AggregationJobContinueReq, + resp_media_type: DapMediaType::agg_job_cont_resp_for_version( + task_config.version, + ), + resource: agg_job_id.for_request_path(), + req_data: agg_job_cont_req.get_encoded_with_param(&task_config.version), + method: LeaderHttpRequestMethod::Post, + }, + ) + .await?; + let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload) + .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; + + // Handle AggregationJobResp. + task_config.vdaf.handle_final_agg_job_resp( + task_config, + uncommited, + agg_job_resp, + &metrics, + )? 
} - DapLeaderTransition::Skip => return Ok(0), - DapLeaderTransition::Continue(..) => { + DapLeaderAggregationJobTransition::Finished(Some(agg_share_span)) => agg_share_span, + DapLeaderAggregationJobTransition::Finished(None) => return Ok(0), + DapLeaderAggregationJobTransition::Continued(..) => { return Err(fatal_error!(err = "unexpected state transition (continue)").into()) } }; - // Send AggregationJobContinueReq and receive AggregationJobResp. - let resp = leader_send_http_request( - self, - task_id, - task_config, - LeaderHttpRequestOptions { - path: &url_path, - req_media_type: DapMediaType::AggregationJobContinueReq, - resp_media_type: DapMediaType::agg_job_cont_resp_for_version(task_config.version), - resource: agg_job_id.for_request_path(), - req_data: agg_job_cont_req.get_encoded_with_param(&task_config.version), - method: LeaderHttpRequestMethod::Post, - }, - ) - .await?; - let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload) - .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?; - - // Commit the output shares. - let agg_share_span = task_config.vdaf.handle_final_agg_job_resp( - task_config, - uncommited, - agg_job_resp, - &metrics, - )?; let out_shares_count = agg_share_span.report_count() as u64; // At this point we're committed to aggregating the reports: if we do detect a report was // replayed at this stage, then we may end up with a batch mismatch. However, this should // only happen if there are multiple aggregation jobs in-flight that include the same // report. 
- let replayed = self .try_put_agg_share_span(task_id, task_config, agg_share_span) .await?; - if let Some(replayed) = replayed { tracing::warn!( replay_count = replayed.len(), diff --git a/daphne/src/roles/mod.rs b/daphne/src/roles/mod.rs index 9e2b45aa7..26ccd7300 100644 --- a/daphne/src/roles/mod.rs +++ b/daphne/src/roles/mod.rs @@ -206,15 +206,15 @@ mod test { taskprov, AggregateShareReq, AggregationJobContinueReq, AggregationJobInitReq, AggregationJobResp, BatchId, BatchSelector, Collection, CollectionJobId, CollectionReq, Extension, Interval, PartialBatchSelector, Query, Report, ReportId, ReportMetadata, - ReportShare, TaskId, Time, Transition, TransitionFailure, TransitionVar, + TaskId, Time, Transition, TransitionFailure, TransitionVar, }, taskprov::TaskprovVersion, test_versions, testing::{AggStore, MockAggregator, MockAggregatorReportSelector}, vdaf::VdafVerifyKey, DapAbort, DapAggregateShare, DapBatchBucket, DapCollectJob, DapGlobalConfig, - DapMeasurement, DapQueryConfig, DapRequest, DapResource, DapTaskConfig, DapVersion, - MetaAggregationJobId, Prio3Config, VdafConfig, + DapLeaderAggregationJobTransition, DapMeasurement, DapQueryConfig, DapRequest, DapResource, + DapTaskConfig, DapVersion, MetaAggregationJobId, Prio3Config, VdafConfig, }; use assert_matches::assert_matches; use matchit::Router; @@ -484,7 +484,7 @@ mod test { &self, task_id: &TaskId, version: DapVersion, - report_shares: Vec, + reports: Vec, ) -> DapRequest { let mut rng = thread_rng(); let task_config = self.leader.unchecked_get_task_config(task_id).await; @@ -496,18 +496,32 @@ mod test { }; let agg_job_id = MetaAggregationJobId::gen_for_version(&version); + + let DapLeaderAggregationJobTransition::Continued(_leader_state, agg_job_init_req) = + task_config + .vdaf + .produce_agg_job_init_req( + self.leader.as_ref(), + self.leader.as_ref(), + task_id, + &task_config, + &agg_job_id, + &part_batch_sel, + reports, + &self.leader.metrics.with_host("fake-host"), + ) + .await + .unwrap() + 
else { + panic!("unexpected transition"); + }; + self.leader_authorized_req( task_id, &task_config, Some(&agg_job_id), DapMediaType::AggregationJobInitReq, - AggregationJobInitReq { - draft02_task_id: task_id.for_request_payload(&version), - draft02_agg_job_id: agg_job_id.for_request_payload(), - agg_param: Vec::default(), - part_batch_sel, - report_shares, - }, + agg_job_init_req, task_config.helper_url.join("aggregate").unwrap(), ) .await @@ -766,7 +780,7 @@ mod test { part_batch_sel: PartialBatchSelector::FixedSizeByBatchId { batch_id: BatchId(rng.gen()), }, - report_shares: Vec::default(), + prep_inits: Vec::default(), }, task_config.helper_url.join("aggregate").unwrap(), ) @@ -781,10 +795,41 @@ mod test { async_test_versions! { handle_agg_job_req_invalid_batch_sel } - async fn handle_agg_job_req_init_unauthorized_request(version: DapVersion) { + // TODO(cjpatton) Re-enable this test. We need to refactor so that we can produce the + // AggregationJobInitReq without invoking `produce_agg_job_init_req()`, which filters reports + // passed the expiration date. + // + // async fn handle_agg_job_req_init_expired_task(version: DapVersion) { + // let t = Test::new(version); + // + // let report = t.gen_test_report(&t.expired_task_id).await; + // let report_share = ReportShare { + // report_metadata: report.report_metadata, + // public_share: report.public_share, + // encrypted_input_share: report.encrypted_input_shares[1].clone(), + // }; + // let req = t + // .gen_test_agg_job_init_req(&t.expired_task_id, version, vec![report_share]) + // .await; + // + // let resp = t.helper.handle_agg_job_req(&req).await.unwrap(); + // let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload).unwrap(); + // assert_eq!(agg_job_resp.transitions.len(), 1); + // assert_matches!( + // agg_job_resp.transitions[0].var, + // TransitionVar::Failed(TransitionFailure::TaskExpired) + // ); + // + // assert_eq!(t.helper.audit_log.invocations(), 1); + // } + // + // async_test_versions! 
{ handle_agg_job_req_init_expired_task } + + async fn handle_agg_job_init_req_unauthorized_request(version: DapVersion) { let t = Test::new(version); + let report = t.gen_test_report(&t.time_interval_task_id).await; let mut req = t - .gen_test_agg_job_init_req(&t.time_interval_task_id, version, Vec::default()) + .gen_test_agg_job_init_req(&t.time_interval_task_id, version, vec![report]) .await; req.sender_auth = None; @@ -804,51 +849,20 @@ mod test { assert_eq!(t.helper.audit_log.invocations(), 0); } - async_test_versions! { handle_agg_job_req_init_unauthorized_request } - - // Test that the Helper rejects reports past the expiration date. - async fn handle_agg_job_req_init_expired_task(version: DapVersion) { - let t = Test::new(version); - - let report = t.gen_test_report(&t.expired_task_id).await; - let report_share = ReportShare { - report_metadata: report.report_metadata, - public_share: report.public_share, - encrypted_input_share: report.encrypted_input_shares[1].clone(), - }; - let req = t - .gen_test_agg_job_init_req(&t.expired_task_id, version, vec![report_share]) - .await; - - let resp = t.helper.handle_agg_job_req(&req).await.unwrap(); - let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload).unwrap(); - assert_eq!(agg_job_resp.transitions.len(), 1); - assert_matches!( - agg_job_resp.transitions[0].var, - TransitionVar::Failed(TransitionFailure::TaskExpired) - ); - - assert_eq!(t.helper.audit_log.invocations(), 1); - } - - async_test_versions! { handle_agg_job_req_init_expired_task } + async_test_versions! { handle_agg_job_init_req_unauthorized_request } // Test that the Helper rejects reports with a bad round number. - async fn handle_agg_job_req_bad_round(version: DapVersion) { - let t = Test::new(version); - if version == DapVersion::Draft02 { + #[tokio::test] + async fn handle_agg_job_req_bad_round_draft02() { + let t = Test::new(DapVersion::Draft02); + if DapVersion::Draft02 == DapVersion::Draft02 { // Nothing to test. 
return; } let report = t.gen_test_report(&t.time_interval_task_id).await; - let report_share = ReportShare { - report_metadata: report.report_metadata, - public_share: report.public_share, - encrypted_input_share: report.encrypted_input_shares[1].clone(), - }; let req = t - .gen_test_agg_job_init_req(&t.time_interval_task_id, version, vec![report_share]) + .gen_test_agg_job_init_req(&t.time_interval_task_id, DapVersion::Draft02, vec![report]) .await; let agg_job_id = match &req.resource { DapResource::AggregationJob(agg_job_id) => agg_job_id.clone(), @@ -879,56 +893,6 @@ mod test { assert_eq!(t.helper.audit_log.invocations(), 1); } - async_test_versions! { handle_agg_job_req_bad_round } - - // Test that the Helper rejects reports with a bad round id - async fn handle_agg_job_req_zero_round(version: DapVersion) { - let t = Test::new(version); - if version == DapVersion::Draft02 { - // Nothing to test. - return; - } - - let report = t.gen_test_report(&t.time_interval_task_id).await; - let report_share = ReportShare { - report_metadata: report.report_metadata, - public_share: report.public_share, - encrypted_input_share: report.encrypted_input_shares[1].clone(), - }; - let req = t - .gen_test_agg_job_init_req(&t.time_interval_task_id, version, vec![report_share]) - .await; - let agg_job_id = match &req.resource { - DapResource::AggregationJob(agg_job_id) => agg_job_id.clone(), - _ => panic!("agg_job_id resource missing!"), - }; - let resp = t.helper.handle_agg_job_req(&req).await.unwrap(); - - // AggregationJobInitReq succeeds - assert_eq!(t.helper.audit_log.invocations(), 1); - - let agg_job_resp = AggregationJobResp::get_decoded(&resp.payload).unwrap(); - assert_eq!(agg_job_resp.transitions.len(), 1); - assert_matches!(agg_job_resp.transitions[0].var, TransitionVar::Continued(_)); - // Test wrong round - let req = t - .gen_test_agg_job_cont_req_with_round( - &MetaAggregationJobId::Draft07(Cow::Borrowed(&agg_job_id)), - agg_job_resp.transitions, - Some(0), - ) - 
.await; - assert_matches!( - t.helper.handle_agg_job_req(&req).await, - Err(DapAbort::UnrecognizedMessage { .. }) - ); - - // AggregationJobContinueReq fails - assert_eq!(t.helper.audit_log.invocations(), 1); - } - - async_test_versions! { handle_agg_job_req_zero_round } - async fn handle_hpke_config_req_unrecognized_task(version: DapVersion) { let t = Test::new(version); let mut rng = thread_rng(); @@ -978,7 +942,7 @@ mod test { async_test_versions! { handle_hpke_config_req_missing_task_id } - async fn handle_agg_job_req_cont_unauthorized_request(version: DapVersion) { + async fn handle_agg_job_cont_req_unauthorized_request(version: DapVersion) { let t = Test::new(version); let agg_job_id = MetaAggregationJobId::gen_for_version(&version); let mut req = t @@ -1002,7 +966,7 @@ mod test { assert_eq!(t.helper.audit_log.invocations(), 0); } - async_test_versions! { handle_agg_job_req_cont_unauthorized_request } + async_test_versions! { handle_agg_job_cont_req_unauthorized_request } async fn handle_agg_share_req_unauthorized_request(version: DapVersion) { let t = Test::new(version); @@ -1143,20 +1107,10 @@ mod test { let t = Test::new(version); let task_id = &t.time_interval_task_id; - let report = t.gen_test_report(task_id).await; - let (report_metadata, public_share, mut encrypted_input_share) = ( - report.report_metadata, - report.public_share, - report.encrypted_input_shares[1].clone(), - ); - encrypted_input_share.payload[0] ^= 0xff; // Cause decryption to fail - let report_shares = vec![ReportShare { - report_metadata, - public_share, - encrypted_input_share, - }]; + let mut report = t.gen_test_report(task_id).await; + report.encrypted_input_shares[1].payload[0] ^= 0xff; // Cause decryption to fail let req = t - .gen_test_agg_job_init_req(task_id, version, report_shares) + .gen_test_agg_job_init_req(task_id, version, vec![report]) .await; // Get AggregationJobResp and then extract the transition data from inside. 
@@ -1180,14 +1134,8 @@ mod test { let task_id = &t.time_interval_task_id; let report = t.gen_test_report(task_id).await; - let report_shares = vec![ReportShare { - report_metadata: report.report_metadata.clone(), - public_share: report.public_share, - // 1st share is for Leader and the rest is for Helpers (note that there is only 1 helper). - encrypted_input_share: report.encrypted_input_shares[1].clone(), - }]; let req = t - .gen_test_agg_job_init_req(task_id, version, report_shares) + .gen_test_agg_job_init_req(task_id, version, vec![report]) .await; // Get AggregationJobResp and then extract the transition data from inside. @@ -1208,14 +1156,8 @@ mod test { let task_id = &t.time_interval_task_id; let report = t.gen_test_report(task_id).await; - let report_shares = vec![ReportShare { - report_metadata: report.report_metadata.clone(), - public_share: report.public_share, - // 1st share is for Leader and the rest is for Helpers (note that there is only 1 helper). - encrypted_input_share: report.encrypted_input_shares[1].clone(), - }]; let req = t - .gen_test_agg_job_init_req(task_id, version, report_shares) + .gen_test_agg_job_init_req(task_id, version, vec![report.clone()]) .await; // Add dummy data to report store backend. This is done in a new scope so that the lock on the @@ -1260,14 +1202,8 @@ mod test { let task_config = t.helper.unchecked_get_task_config(task_id).await; let report = t.gen_test_report(task_id).await; - let report_shares = vec![ReportShare { - report_metadata: report.report_metadata.clone(), - public_share: report.public_share, - // 1st share is for Leader and the rest is for Helpers (note that there is only 1 helper). - encrypted_input_share: report.encrypted_input_shares[1].clone(), - }]; let req = t - .gen_test_agg_job_init_req(task_id, version, report_shares) + .gen_test_agg_job_init_req(task_id, version, vec![report]) .await; // Add mock data to the aggreagte store backend. 
This is done in its own scope so that the lock @@ -1315,19 +1251,14 @@ mod test { async_test_versions! { handle_agg_job_req_failure_batch_collected } - async fn handle_agg_job_req_abort_helper_state_overwritten(version: DapVersion) { - let t = Test::new(version); + #[tokio::test] + async fn handle_agg_job_req_abort_helper_state_overwritten_draft02() { + let t = Test::new(DapVersion::Draft02); let task_id = &t.time_interval_task_id; let report = t.gen_test_report(task_id).await; - let report_shares = vec![ReportShare { - report_metadata: report.report_metadata.clone(), - public_share: report.public_share, - // 1st share is for Leader and the rest is for Helpers (note that there is only 1 helper). - encrypted_input_share: report.encrypted_input_shares[1].clone(), - }]; let req = t - .gen_test_agg_job_init_req(task_id, version, report_shares) + .gen_test_agg_job_init_req(task_id, DapVersion::Draft02, vec![report]) .await; // Send aggregate request. @@ -1346,8 +1277,6 @@ mod test { ); } - async_test_versions! 
{ handle_agg_job_req_abort_helper_state_overwritten } - async fn handle_agg_job_req_fail_send_cont_req(version: DapVersion) { let t = Test::new(version); let agg_job_id = MetaAggregationJobId::gen_for_version(&version); @@ -1846,8 +1775,14 @@ mod test { let query = task_config.query_for_current_batch_window(t.now); t.run_col_job(task_id, &query).await.unwrap(); + let agg_job_req_count = match version { + DapVersion::Draft02 => 2, + DapVersion::Draft07 => 1, + _ => panic!("unhandled version {version:?}"), + }; + assert_metrics_include!(t.prometheus_registry, { - r#"test_helper_inbound_request_counter{host="helper.org",type="aggregate"}"#: 2, + r#"test_helper_inbound_request_counter{host="helper.org",type="aggregate"}"#: agg_job_req_count, r#"test_helper_inbound_request_counter{host="helper.org",type="collect"}"#: 1, r#"test_leader_report_counter{host="leader.com",status="aggregated"}"#: 1, r#"test_helper_report_counter{host="helper.org",status="aggregated"}"#: 1, @@ -1880,8 +1815,13 @@ mod test { }; t.run_col_job(task_id, &query).await.unwrap(); + let agg_job_req_count = match version { + DapVersion::Draft02 => 2, + DapVersion::Draft07 => 1, + _ => panic!("unhandled version {version:?}"), + }; assert_metrics_include!(t.prometheus_registry, { - r#"test_helper_inbound_request_counter{host="helper.org",type="aggregate"}"#: 2, + r#"test_helper_inbound_request_counter{host="helper.org",type="aggregate"}"#: agg_job_req_count, r#"test_helper_inbound_request_counter{host="helper.org",type="collect"}"#: 1, r#"test_leader_report_counter{host="leader.com",status="aggregated"}"#: 1, r#"test_helper_report_counter{host="helper.org",status="aggregated"}"#: 1, diff --git a/daphne/src/testing.rs b/daphne/src/testing.rs index ede4a72bc..4eff107fb 100644 --- a/daphne/src/testing.rs +++ b/daphne/src/testing.rs @@ -18,10 +18,11 @@ use crate::{ metrics::DaphneMetrics, roles::{DapAggregator, DapAuthorizedSender, DapHelper, DapLeader, DapReportInitializer}, vdaf::{EarlyReportState, 
EarlyReportStateConsumed, EarlyReportStateInitialized}, - DapAbort, DapAggregateResult, DapAggregateShare, DapAggregateShareSpan, DapBatchBucket, - DapCollectJob, DapError, DapGlobalConfig, DapHelperState, DapHelperTransition, DapLeaderState, - DapLeaderTransition, DapLeaderUncommitted, DapMeasurement, DapQueryConfig, DapRequest, - DapResponse, DapTaskConfig, DapVersion, MetaAggregationJobId, VdafConfig, + DapAbort, DapAggregateResult, DapAggregateShare, DapAggregateShareSpan, DapAggregationJobState, + DapAggregationJobUncommitted, DapBatchBucket, DapCollectJob, DapError, DapGlobalConfig, + DapHelperAggregationJobTransition, DapLeaderAggregationJobTransition, DapMeasurement, + DapQueryConfig, DapRequest, DapResponse, DapTaskConfig, DapVersion, MetaAggregationJobId, + VdafConfig, }; use assert_matches::assert_matches; use async_trait::async_trait; @@ -190,7 +191,7 @@ impl AggregationJobTest { pub async fn produce_agg_job_init_req( &self, reports: Vec, - ) -> DapLeaderTransition { + ) -> DapLeaderAggregationJobTransition { let metrics = self .leader_metrics .with_host(self.task_config.leader_url.host_str().unwrap()); @@ -216,7 +217,7 @@ impl AggregationJobTest { pub async fn handle_agg_job_init_req( &self, agg_job_init_req: &AggregationJobInitReq, - ) -> DapHelperTransition { + ) -> DapHelperAggregationJobTransition { let metrics = self .helper_metrics .with_host(self.task_config.helper_url.host_str().unwrap()); @@ -227,6 +228,7 @@ impl AggregationJobTest { self, &self.task_id, &self.task_config, + |_id| false, agg_job_init_req, &metrics, ) @@ -239,9 +241,9 @@ impl AggregationJobTest { /// Panics if the Leader aborts. 
pub fn handle_agg_job_resp( &self, - leader_state: DapLeaderState, + leader_state: DapAggregationJobState, agg_job_resp: AggregationJobResp, - ) -> DapLeaderTransition { + ) -> DapLeaderAggregationJobTransition { let metrics = self .leader_metrics .with_host(self.task_config.leader_url.host_str().unwrap()); @@ -249,10 +251,10 @@ impl AggregationJobTest { .vdaf .handle_agg_job_resp( &self.task_id, + &self.task_config, &self.agg_job_id, leader_state, agg_job_resp, - self.task_config.version, &metrics, ) .unwrap() @@ -261,7 +263,7 @@ impl AggregationJobTest { /// Like [`handle_agg_job_resp`] but expect the Leader to abort. pub fn handle_agg_job_resp_expect_err( &self, - leader_state: DapLeaderState, + leader_state: DapAggregationJobState, agg_job_resp: AggregationJobResp, ) -> DapAbort { let metrics = self @@ -271,10 +273,10 @@ impl AggregationJobTest { .vdaf .handle_agg_job_resp( &self.task_id, + &self.task_config, &self.agg_job_id, leader_state, agg_job_resp, - self.task_config.version, &metrics, ) .expect_err("handle_agg_job_resp() succeeded; expected failure") @@ -285,7 +287,7 @@ impl AggregationJobTest { /// Panics if the Helper aborts. pub fn handle_agg_job_cont_req( &self, - helper_state: &DapHelperState, + helper_state: &DapAggregationJobState, agg_job_cont_req: &AggregationJobContinueReq, ) -> (DapAggregateShareSpan, AggregationJobResp) { let metrics = self @@ -308,7 +310,7 @@ impl AggregationJobTest { /// Like [`handle_agg_job_cont_req`] but expect the Helper to abort. pub fn handle_agg_job_cont_req_expect_err( &self, - helper_state: DapHelperState, + helper_state: DapAggregationJobState, agg_job_cont_req: &AggregationJobContinueReq, ) -> DapAbort { let metrics = self @@ -333,7 +335,7 @@ impl AggregationJobTest { /// Panics if the Leader aborts. 
pub fn handle_final_agg_job_resp( &self, - leader_uncommitted: DapLeaderUncommitted, + leader_uncommitted: DapAggregationJobUncommitted, agg_job_resp: AggregationJobResp, ) -> DapAggregateShareSpan { let metrics = self @@ -420,28 +422,45 @@ impl AggregationJobTest { let reports = self.produce_reports(measurements); // Aggregators: Preparation - let DapLeaderTransition::Continue(leader_state, agg_job_init_req) = + let DapLeaderAggregationJobTransition::Continued(leader_state, agg_job_init_req) = self.produce_agg_job_init_req(reports).await else { panic!("unexpected transition"); }; - let DapHelperTransition::Continue(helper_state, agg_job_resp) = - self.handle_agg_job_init_req(&agg_job_init_req).await - else { - panic!("unexpected transition"); - }; - let got = DapHelperState::get_decoded(&self.task_config.vdaf, &helper_state.get_encoded()) - .expect("failed to decode helper state"); - assert_eq!(got, helper_state); - let DapLeaderTransition::Uncommitted(uncommitted, agg_cont) = - self.handle_agg_job_resp(leader_state, agg_job_resp) - else { - panic!("unexpected transition"); + let (leader_share_span, helper_share_span) = match self + .handle_agg_job_init_req(&agg_job_init_req) + .await + { + DapHelperAggregationJobTransition::Continued(helper_state, agg_job_resp) => { + let got = DapAggregationJobState::get_decoded( + &self.task_config.vdaf, + &helper_state.get_encoded(), + ) + .expect("failed to decode helper state"); + assert_eq!(got.get_encoded(), helper_state.get_encoded()); + + let DapLeaderAggregationJobTransition::Uncommitted(uncommitted, agg_cont) = + self.handle_agg_job_resp(leader_state, agg_job_resp) + else { + panic!("unexpected transition"); + }; + let (helper_share_span, transitions) = + self.handle_agg_job_cont_req(&helper_state, &agg_cont); + let leader_share_span = self.handle_final_agg_job_resp(uncommitted, transitions); + (leader_share_span, helper_share_span) + } + DapHelperAggregationJobTransition::Finished(Some(helper_share_span), 
agg_job_resp) => { + let DapLeaderAggregationJobTransition::Finished(Some(leader_share_span)) = + self.handle_agg_job_resp(leader_state, agg_job_resp) + else { + panic!("unexpected transition"); + }; + (leader_share_span, helper_share_span) + } + _ => panic!("unexpected transition"), }; - let (helper_share_span, transitions) = - self.handle_agg_job_cont_req(&helper_state, &agg_cont); - let leader_share_span = self.handle_final_agg_job_resp(uncommitted, transitions); + let report_count = u64::try_from(leader_share_span.report_count()).unwrap(); // Leader: Aggregation @@ -590,7 +609,7 @@ pub struct MockAggregator { pub collector_token: Option, // Not set by Helper pub report_store: Arc>>, pub leader_state_store: Arc>>, - pub helper_state_store: Arc>>, + pub helper_state_store: Arc>>, pub agg_store: Arc>>>, pub collector_hpke_config: HpkeConfig, pub metrics: DaphneMetrics, @@ -1190,7 +1209,7 @@ impl DapHelper for MockAggregator { &self, task_id: &TaskId, agg_job_id: &MetaAggregationJobId, - helper_state: &DapHelperState, + helper_state: &DapAggregationJobState, ) -> Result { let helper_state_info = HelperStateInfo { task_id: task_id.clone(), @@ -1219,7 +1238,7 @@ impl DapHelper for MockAggregator { &self, task_id: &TaskId, agg_job_id: &MetaAggregationJobId, - ) -> Result, DapError> { + ) -> Result, DapError> { let helper_state_info = HelperStateInfo { task_id: task_id.clone(), agg_job_id_owned: agg_job_id.into(), diff --git a/daphne/src/vdaf/mod.rs b/daphne/src/vdaf/mod.rs index f52770a34..ee39430f3 100644 --- a/daphne/src/vdaf/mod.rs +++ b/daphne/src/vdaf/mod.rs @@ -14,8 +14,8 @@ use crate::{ messages::{ encode_u32_bytes, AggregationJobContinueReq, AggregationJobInitReq, AggregationJobResp, BatchSelector, Extension, HpkeCiphertext, PartialBatchSelector, PlaintextInputShare, - Report, ReportId, ReportMetadata, ReportShare, TaskId, Time, Transition, TransitionFailure, - TransitionVar, + PrepareInit, Report, ReportId, ReportMetadata, ReportShare, TaskId, Time, 
Transition, + TransitionFailure, TransitionVar, }, metrics::ContextualizedDaphneMetrics, roles::DapReportInitializer, @@ -29,8 +29,9 @@ use crate::{ prio3_prep_init, prio3_shard, prio3_unshard, }, }, - DapAggregateResult, DapAggregateShare, DapAggregateShareSpan, DapError, DapHelperState, - DapHelperTransition, DapLeaderState, DapLeaderTransition, DapLeaderUncommitted, DapMeasurement, + AggregationJobReportState, DapAggregateResult, DapAggregateShare, DapAggregateShareSpan, + DapAggregationJobState, DapAggregationJobUncommitted, DapError, + DapHelperAggregationJobTransition, DapLeaderAggregationJobTransition, DapMeasurement, DapOutputShare, DapTaskConfig, DapVersion, MetaAggregationJobId, VdafConfig, }; use prio::{ @@ -320,7 +321,8 @@ impl EarlyReportState for EarlyReportStateInitialized<'_> { } /// VDAF preparation state. -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone)] +#[cfg_attr(any(test, feature = "test-utils"), derive(Debug, Eq, PartialEq))] pub enum VdafPrepState { Prio2(Prio2PrepareState), Prio3Field64(Prio3PrepareState), @@ -372,13 +374,25 @@ impl<'a> ParameterizedDecode<(&'a VdafConfig, bool /* is_leader */)> for VdafPre } /// VDAF preparation message. 
-#[derive(Clone, Debug)] +#[derive(Clone)] +#[cfg_attr(any(test, feature = "test-utils"), derive(Debug))] pub enum VdafPrepMessage { Prio2Share(Prio2PrepareShare), Prio3ShareField64(Prio3PrepareShare), Prio3ShareField128(Prio3PrepareShare), } +#[cfg(any(test, feature = "test-utils"))] +impl deepsize::DeepSizeOf for VdafPrepMessage { + fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { + match self { + Self::Prio2Share(msg) => std::mem::size_of_val(msg), + Self::Prio3ShareField64(msg) => std::mem::size_of_val(msg), + Self::Prio3ShareField128(msg) => std::mem::size_of_val(msg), + } + } +} + impl Encode for VdafPrepMessage { fn encode(&self, bytes: &mut Vec) { match self { @@ -669,7 +683,7 @@ impl VdafConfig { part_batch_sel: &PartialBatchSelector, reports: Vec, metrics: &ContextualizedDaphneMetrics<'_>, - ) -> Result, DapAbort> { + ) -> Result, DapAbort> { let mut processed = HashSet::with_capacity(reports.len()); let mut states = Vec::with_capacity(reports.len()); let mut seq = Vec::with_capacity(reports.len()); @@ -721,11 +735,30 @@ impl VdafConfig { state, message, } => { - states.push((state, message, metadata.time, metadata.id.clone())); - seq.push(ReportShare { - report_metadata: metadata.into_owned(), - public_share: public_share.into_owned(), - encrypted_input_share: helper_share, + // draft02 compatibility: In the latest version, the Leader sends the Helper + // its initial prep share in the first request. 
+ let (draft02_prep_share, draft05_payload) = match task_config.version { + DapVersion::Draft02 => (Some(message), None), + DapVersion::Draft07 => ( + None, + Some(message.get_encoded_with_param(&task_config.version)), + ), + v => unreachable!("unhandled version {v:?}"), + }; + + states.push(AggregationJobReportState { + draft02_prep_share, + prep_state: state, + time: metadata.time, + report_id: metadata.id.clone(), + }); + seq.push(PrepareInit { + report_share: ReportShare { + report_metadata: metadata.into_owned(), + public_share: public_share.into_owned(), + encrypted_input_share: helper_share, + }, + draft05_payload, }); } @@ -738,11 +771,11 @@ impl VdafConfig { } if seq.is_empty() { - return Ok(DapLeaderTransition::Skip); + return Ok(DapLeaderAggregationJobTransition::Finished(None)); } - Ok(DapLeaderTransition::Continue( - DapLeaderState { + Ok(DapLeaderAggregationJobTransition::Continued( + DapAggregationJobState { seq: states, part_batch_sel: part_batch_sel.clone(), }, @@ -751,7 +784,7 @@ impl VdafConfig { draft02_agg_job_id: agg_job_id.for_request_payload(), agg_param: Vec::default(), part_batch_sel: part_batch_sel.clone(), - report_shares: seq, + prep_inits: seq, }, )) } @@ -777,31 +810,36 @@ impl VdafConfig { /// * `agg_job_init_req` is the request sent by the Leader. /// /// * `version` is the DapVersion to use. + // TODO(cjpatton) Remove this clippy exemption after addressing + // https://github.com/cloudflare/daphne/issues/406. 
+ #[allow(clippy::too_many_arguments)] pub(crate) async fn handle_agg_job_init_req( &self, decrypter: &impl HpkeDecrypter, initializer: &impl DapReportInitializer, task_id: &TaskId, task_config: &DapTaskConfig, + is_replay: impl Fn(&ReportId) -> bool, agg_job_init_req: &AggregationJobInitReq, metrics: &ContextualizedDaphneMetrics<'_>, - ) -> Result, DapAbort> { - let num_reports = agg_job_init_req.report_shares.len(); + ) -> Result, DapAbort> { + let num_reports = agg_job_init_req.prep_inits.len(); let mut processed = HashSet::with_capacity(num_reports); - let mut states = Vec::with_capacity(num_reports); + let mut draft02_states = Vec::with_capacity(num_reports); + let mut draft05_agg_share_span = DapAggregateShareSpan::default(); let mut transitions = Vec::with_capacity(num_reports); let mut consumed_reports = Vec::with_capacity(num_reports); - for report_share in agg_job_init_req.report_shares.iter() { - if processed.contains(&report_share.report_metadata.id) { + for prep_init in agg_job_init_req.prep_inits.iter() { + if processed.contains(&prep_init.report_share.report_metadata.id) { return Err(DapAbort::UnrecognizedMessage { detail: format!( "report ID {} appears twice in the same aggregation job", - report_share.report_metadata.id.to_base64url() + prep_init.report_share.report_metadata.id.to_base64url() ), task_id: Some(task_id.clone()), }); } - processed.insert(report_share.report_metadata.id.clone()); + processed.insert(prep_init.report_share.report_metadata.id.clone()); consumed_reports.push( EarlyReportStateConsumed::consume( @@ -809,9 +847,9 @@ impl VdafConfig { false, task_id, task_config, - Cow::Borrowed(&report_share.report_metadata), - Cow::Borrowed(&report_share.public_share), - &report_share.encrypted_input_share, + Cow::Borrowed(&prep_init.report_share.report_metadata), + Cow::Borrowed(&prep_init.report_share.public_share), + &prep_init.report_share.encrypted_input_share, ) .await?, ); @@ -827,52 +865,123 @@ impl VdafConfig { ) .await?; - for 
initialized_report in initialized_reports.into_iter() { - let transition = match initialized_report { - EarlyReportStateInitialized::Ready { - metadata, - public_share: _, - state, - message, - } => { - states.push((state, metadata.time, metadata.id.clone())); - Transition { - report_id: metadata.into_owned().id, - var: TransitionVar::Continued(message.get_encoded()), - } - } - - EarlyReportStateInitialized::Rejected { metadata, failure } => { - metrics.report_inc_by(&format!("rejected_{failure}"), 1); - Transition { - report_id: metadata.into_owned().id, - var: TransitionVar::Failed(failure), + for (initialized_report, prep_init) in initialized_reports + .into_iter() + .zip(agg_job_init_req.prep_inits.iter()) + { + let var = if is_replay(&prep_init.report_share.report_metadata.id) { + let failure = TransitionFailure::ReportReplayed; + metrics.report_inc_by(&format!("rejected_{failure}",), 1); + TransitionVar::Failed(failure) + } else { + match initialized_report { + EarlyReportStateInitialized::Ready { + metadata, + public_share: _, + state: helper_prep_state, + message: helper_prep_share, + } => match task_config.version { + DapVersion::Draft02 => { + draft02_states.push(AggregationJobReportState { + draft02_prep_share: None, + prep_state: helper_prep_state, + time: metadata.time, + report_id: metadata.id.clone(), + }); + TransitionVar::Continued(helper_prep_share.get_encoded()) + } + DapVersion::Draft07 => { + let Some(ref leader_prep_share) = prep_init.draft05_payload else { + return Err(DapAbort::UnrecognizedMessage { + detail: "PrepareInit with missing payload".to_string(), + task_id: Some(task_id.clone()), + }); + }; + + let res = match self { + Self::Prio3(prio3_config) => prio3_prep_finish_from_shares( + prio3_config, + 1, + helper_prep_state, + helper_prep_share, + leader_prep_share, + ), + Self::Prio2 { dimension } => prio2_prep_finish_from_shares( + *dimension, + helper_prep_state, + helper_prep_share, + leader_prep_share, + ), + }; + + match res { + 
Ok((data, prep_msg)) => { + draft05_agg_share_span.add_out_share( + task_config, + &agg_job_init_req.part_batch_sel, + metadata.id.clone(), + metadata.time, + data, + )?; + TransitionVar::Continued(prep_msg) + } + + Err(VdafError::Codec(..)) | Err(VdafError::Vdaf(..)) => { + let failure = TransitionFailure::VdafPrepError; + metrics.report_inc_by(&format!("rejected_{failure}"), 1); + TransitionVar::Failed(failure) + } + } + } + v => unreachable!("unhandled version {v:?}"), + }, + + EarlyReportStateInitialized::Rejected { + metadata: _, + failure, + } => { + metrics.report_inc_by(&format!("rejected_{failure}"), 1); + TransitionVar::Failed(failure) } } }; - transitions.push(transition); + transitions.push(Transition { + report_id: prep_init.report_share.report_metadata.id.clone(), + var, + }); } - Ok(DapHelperTransition::Continue( - DapHelperState { - part_batch_sel: agg_job_init_req.part_batch_sel.clone(), - seq: states, - }, - AggregationJobResp { transitions }, - )) + match task_config.version { + DapVersion::Draft02 => Ok(DapHelperAggregationJobTransition::Continued( + DapAggregationJobState { + part_batch_sel: agg_job_init_req.part_batch_sel.clone(), + seq: draft02_states, + }, + AggregationJobResp { transitions }, + )), + DapVersion::Draft07 => Ok(DapHelperAggregationJobTransition::Finished( + if draft05_agg_share_span.report_count() > 0 { + Some(draft05_agg_share_span) + } else { + None + }, + AggregationJobResp { transitions }, + )), + v => unreachable!("unhandled version {v:?}"), + } } /// Handle an aggregate response from the Helper. This method is run by the Leader. 
pub fn handle_agg_job_resp( &self, task_id: &TaskId, + task_config: &DapTaskConfig, agg_job_id: &MetaAggregationJobId, - state: DapLeaderState, + state: DapAggregationJobState, agg_job_resp: AggregationJobResp, - version: DapVersion, metrics: &ContextualizedDaphneMetrics<'_>, - ) -> Result, DapAbort> { + ) -> Result, DapAbort> { if agg_job_resp.transitions.len() != state.seq.len() { return Err(DapAbort::UnrecognizedMessage { detail: format!( @@ -884,15 +993,15 @@ impl VdafConfig { }); } - let mut seq = Vec::with_capacity(state.seq.len()); - let mut states = Vec::with_capacity(state.seq.len()); - for (helper, (leader_step, leader_message, leader_time, leader_report_id)) in agg_job_resp + let mut draft02_transitions = Vec::with_capacity(state.seq.len()); + let mut draft02_out_shares = Vec::with_capacity(state.seq.len()); + let mut draft05_agg_share_span = DapAggregateShareSpan::default(); + for (helper, leader) in agg_job_resp .transitions .into_iter() .zip(state.seq.into_iter()) { - // TODO spec: Consider removing the report ID from the AggregationJobResp. - if helper.report_id != leader_report_id { + if helper.report_id != leader.report_id { return Err(DapAbort::UnrecognizedMessage { detail: format!( "report ID {} appears out of order in aggregation job response", @@ -902,8 +1011,8 @@ impl VdafConfig { }); } - let helper_message = match &helper.var { - TransitionVar::Continued(message) => message, + let helper_prep_share = match &helper.var { + TransitionVar::Continued(payload) => payload, // Skip report that can't be processed any further. 
TransitionVar::Failed(failure) => { @@ -920,67 +1029,109 @@ impl VdafConfig { } }; - let res = match self { - Self::Prio3(prio3_config) => prio3_prep_finish_from_shares( - prio3_config, - 0, - leader_step, - leader_message, - helper_message, - ), - Self::Prio2 { dimension } => prio2_prep_finish_from_shares( - *dimension, - leader_step, - leader_message, - helper_message, - ), - }; - - match res { - Ok((data, message)) => { - states.push(( - DapOutputShare { - report_id: leader_report_id.clone(), - time: leader_time, - data, - }, - leader_report_id.clone(), - )); - - seq.push(Transition { - report_id: leader_report_id, - var: TransitionVar::Continued(message), - }); + match task_config.version { + DapVersion::Draft02 => { + let res = match self { + Self::Prio3(prio3_config) => prio3_prep_finish_from_shares( + prio3_config, + 0, + leader.prep_state, + leader.draft02_prep_share.unwrap(), + helper_prep_share, + ), + Self::Prio2 { dimension } => prio2_prep_finish_from_shares( + *dimension, + leader.prep_state, + leader.draft02_prep_share.unwrap(), + helper_prep_share, + ), + }; + + match res { + Ok((data, prep_msg)) => { + draft02_out_shares.push(DapOutputShare { + report_id: leader.report_id.clone(), + time: leader.time, + data, + }); + + draft02_transitions.push(Transition { + report_id: leader.report_id, + var: TransitionVar::Continued(prep_msg), + }); + } + + // Skip report that can't be processed any further. + Err(VdafError::Codec(..)) | Err(VdafError::Vdaf(..)) => { + let failure = TransitionFailure::VdafPrepError; + metrics.report_inc_by(&format!("rejected_{failure}"), 1); + } + }; } - - // Skip report that can't be processed any further. - Err(VdafError::Codec(..)) | Err(VdafError::Vdaf(..)) => { - let failure = TransitionFailure::VdafPrepError; - metrics.report_inc_by(&format!("rejected_{failure}"), 1); + DapVersion::Draft07 => { + let TransitionVar::Continued(ref prep_msg) = helper.var else { + // TODO(cjpatton) Square this with the spec. 
+ return Err(DapAbort::UnrecognizedMessage { + detail: "PrepareResp with unexpected type (expected 'continue')" + .to_string(), + task_id: Some(task_id.clone()), + }); + }; + + let res = match self { + Self::Prio3(prio3_config) => { + prio3_prep_finish(prio3_config, leader.prep_state, prep_msg) + } + Self::Prio2 { dimension } => { + prio2_prep_finish(*dimension, leader.prep_state, prep_msg) + } + }; + + match res { + Ok(data) => { + draft05_agg_share_span.add_out_share( + task_config, + &state.part_batch_sel, + leader.report_id, + leader.time, + data, + )?; + } + + Err(VdafError::Codec(..)) | Err(VdafError::Vdaf(..)) => { + let failure = TransitionFailure::VdafPrepError; + metrics.report_inc_by(&format!("rejected_{failure}"), 1); + } + } } - }; + v => unreachable!("unhandled version {v:?}"), + } } - if seq.is_empty() { - return Ok(DapLeaderTransition::Skip); - } + match task_config.version { + DapVersion::Draft02 => { + if draft02_transitions.is_empty() { + return Ok(DapLeaderAggregationJobTransition::Finished(None)); + } - Ok(DapLeaderTransition::Uncommitted( - DapLeaderUncommitted { - seq: states, - part_batch_sel: state.part_batch_sel, - }, - AggregationJobContinueReq { - draft02_task_id: task_id.for_request_payload(&version), - draft02_agg_job_id: agg_job_id.for_request_payload(), - round: if version == DapVersion::Draft02 { - None - } else { - Some(1) - }, - transitions: seq, - }, - )) + Ok(DapLeaderAggregationJobTransition::Uncommitted( + DapAggregationJobUncommitted { + seq: draft02_out_shares, + part_batch_sel: state.part_batch_sel, + }, + AggregationJobContinueReq { + draft02_task_id: task_id.for_request_payload(&task_config.version), + draft02_agg_job_id: agg_job_id.for_request_payload(), + round: None, + transitions: draft02_transitions, + }, + )) + } + DapVersion::Draft07 => Ok(DapLeaderAggregationJobTransition::Finished(Some( + draft05_agg_share_span, + ))), + v => unreachable!("unhandled version {v:?}"), + } } /// Handle an aggregate request from the 
Leader. This method is called by the Helper. @@ -998,7 +1149,7 @@ impl VdafConfig { &self, task_id: &TaskId, task_config: &DapTaskConfig, - state: &DapHelperState, + state: &DapAggregationJobState, is_replay: impl Fn(&ReportId) -> bool, agg_job_id: &MetaAggregationJobId<'_>, agg_job_cont_req: &AggregationJobContinueReq, @@ -1026,7 +1177,7 @@ impl VdafConfig { let recognized = state .seq .iter() - .map(|(_, _, report_id)| report_id.clone()) + .map(|report_state| report_state.report_id.clone()) .collect::>(); let mut transitions = Vec::with_capacity(state.seq.len()); let mut agg_share_span = DapAggregateShareSpan::default(); @@ -1060,13 +1211,19 @@ impl VdafConfig { } // Find the next helper report that matches leader.report_id. - let next_helper_report = helper_iter.by_ref().find(|(_, _, id)| { + let next_helper_report = helper_iter.by_ref().find(|report_state| { // Presumably the report was removed from the candidate set by the Leader. - processed.insert(id.clone()); - *id == leader.report_id + processed.insert(report_state.report_id.clone()); + report_state.report_id == leader.report_id }); - let Some((helper_step, helper_time, helper_report_id)) = next_helper_report else { + let Some(AggregationJobReportState { + draft02_prep_share: _, + prep_state, + time, + report_id, + }) = next_helper_report + else { // If the Helper iterator is empty, it means the leader passed in more report ids // than we know about. 
break; @@ -1091,10 +1248,10 @@ impl VdafConfig { } else { let res = match self { Self::Prio3(prio3_config) => { - prio3_prep_finish(prio3_config, helper_step.clone(), leader_message) + prio3_prep_finish(prio3_config, prep_state.clone(), leader_message) } Self::Prio2 { dimension } => { - prio2_prep_finish(*dimension, helper_step.clone(), leader_message) + prio2_prep_finish(*dimension, prep_state.clone(), leader_message) } }; @@ -1103,8 +1260,8 @@ impl VdafConfig { agg_share_span.add_out_share( task_config, &state.part_batch_sel, - helper_report_id.clone(), - *helper_time, + report_id.clone(), + *time, data, )?; TransitionVar::Finished @@ -1119,7 +1276,7 @@ impl VdafConfig { }; transitions.push(Transition { - report_id: helper_report_id.clone(), + report_id: report_id.clone(), var, }); } @@ -1131,15 +1288,15 @@ impl VdafConfig { pub fn handle_final_agg_job_resp( &self, task_config: &DapTaskConfig, - uncommitted: DapLeaderUncommitted, + state: DapAggregationJobUncommitted, agg_job_resp: AggregationJobResp, metrics: &ContextualizedDaphneMetrics, ) -> Result { - if agg_job_resp.transitions.len() != uncommitted.seq.len() { + if agg_job_resp.transitions.len() != state.seq.len() { return Err(DapAbort::UnrecognizedMessage { detail: format!( "the Leader has {} reports, but it received {} reports from the Helper", - uncommitted.seq.len(), + state.seq.len(), agg_job_resp.transitions.len() ), task_id: None, @@ -1147,11 +1304,8 @@ impl VdafConfig { } let mut agg_share_span = DapAggregateShareSpan::default(); - for (helper, (out_share, leader_report_id)) in - agg_job_resp.transitions.into_iter().zip(uncommitted.seq) - { - // TODO spec: Consider removing the report ID from the AggregationJobResp. 
- if helper.report_id != leader_report_id { + for (helper, out_share) in agg_job_resp.transitions.into_iter().zip(state.seq) { + if helper.report_id != out_share.report_id { return Err(DapAbort::UnrecognizedMessage { detail: format!( "report ID {} appears out of order in aggregation job response", @@ -1162,7 +1316,6 @@ impl VdafConfig { } match &helper.var { - // TODO Log the fact that the helper sent an unexpected message. TransitionVar::Continued(..) => { return Err(DapAbort::UnrecognizedMessage { detail: "helper sent unexpected `Continued` message".to_string(), @@ -1178,7 +1331,7 @@ impl VdafConfig { TransitionVar::Finished => agg_share_span.add_out_share( task_config, - &uncommitted.part_batch_sel, + &state.part_batch_sel, out_share.report_id.clone(), out_share.time, out_share.data, @@ -1337,18 +1490,19 @@ fn produce_encrypted_agg_share( #[cfg(test)] mod test { use crate::{ - assert_metrics_include, async_test_versions, + assert_metrics_include, assert_metrics_include_auxiliary_function, async_test_versions, error::DapAbort, hpke::{HpkeAeadId, HpkeConfig, HpkeKdfId, HpkeKemId}, messages::{ - AggregationJobInitReq, BatchSelector, Interval, PartialBatchSelector, Report, ReportId, - ReportShare, Transition, TransitionFailure, TransitionVar, + AggregationJobInitReq, BatchSelector, Interval, PartialBatchSelector, PrepareInit, + Report, ReportId, ReportShare, Transition, TransitionFailure, TransitionVar, }, test_versions, testing::AggregationJobTest, - DapAggregateResult, DapAggregateShare, DapError, DapHelperState, DapHelperTransition, - DapLeaderState, DapLeaderTransition, DapLeaderUncommitted, DapMeasurement, DapOutputShare, - DapVersion, Prio3Config, VdafAggregateShare, VdafConfig, VdafPrepMessage, VdafPrepState, + DapAggregateResult, DapAggregateShare, DapAggregateShareSpan, DapAggregationJobState, + DapAggregationJobUncommitted, DapError, DapHelperAggregationJobTransition, + DapLeaderAggregationJobTransition, DapMeasurement, DapVersion, Prio3Config, + 
VdafAggregateShare, VdafConfig, VdafPrepMessage, VdafPrepState, }; use assert_matches::assert_matches; use hpke_rs::HpkePublicKey; @@ -1365,43 +1519,48 @@ mod test { use super::{EarlyReportStateConsumed, EarlyReportStateInitialized}; - impl DapLeaderTransition { - pub(crate) fn unwrap_continue(self) -> (DapLeaderState, M) { + impl DapLeaderAggregationJobTransition { + fn unwrap_continued(self) -> (DapAggregationJobState, M) { match self { - DapLeaderTransition::Continue(state, message) => (state, message), - _ => { - panic!("unexpected transition: got {:?}", self); - } + Self::Continued(state, message) => (state, message), + _ => panic!("unexpected transition"), } } - pub(crate) fn unwrap_uncommitted(self) -> (DapLeaderUncommitted, M) { + fn unwrap_finished(self) -> DapAggregateShareSpan { match self { - DapLeaderTransition::Uncommitted(uncommitted, message) => (uncommitted, message), - _ => { - panic!("unexpected transition: got {:?}", self); - } + Self::Finished(Some(agg_share_span)) => agg_share_span, + _ => panic!("unexpected transition"), + } + } + + pub(crate) fn unwrap_uncommitted(self) -> (DapAggregationJobUncommitted, M) { + match self { + Self::Uncommitted(uncommitted, message) => (uncommitted, message), + _ => panic!("unexpected transition"), } } } - impl DapHelperTransition { - pub(crate) fn unwrap_continue(self) -> (DapHelperState, M) { + impl DapHelperAggregationJobTransition { + fn unwrap_continued(self) -> (DapAggregationJobState, M) { match self { - DapHelperTransition::Continue(state, message) => (state, message), - _ => { - panic!("unexpected transition: got {:?}", self); - } + Self::Continued(state, message) => (state, message), + _ => panic!("unexpected transition"), } } - #[allow(dead_code)] - pub(crate) fn unwrap_finish(self) -> (Vec, M) { + fn unwrap_finished(self) -> (DapAggregateShareSpan, M) { match self { - DapHelperTransition::Finish(out_shares, message) => (out_shares, message), - _ => { - panic!("unexpected transition: got {:?}", 
self); - } + Self::Finished(Some(agg_share_span), msg) => (agg_share_span, msg), + _ => panic!("unexpected transition"), + } + } + + fn unwrap_msg(self) -> M { + match self { + Self::Continued(_, msg) => msg, + Self::Finished(_, msg) => msg, } } } @@ -1554,26 +1713,34 @@ mod test { let (leader_state, agg_job_init_req) = t .produce_agg_job_init_req(reports.clone()) .await - .unwrap_continue(); + .unwrap_continued(); assert_eq!(leader_state.seq.len(), 3); assert_eq!( agg_job_init_req.draft02_task_id, t.task_id.for_request_payload(&version) ); assert_eq!(agg_job_init_req.agg_param.len(), 0); - assert_eq!(agg_job_init_req.report_shares.len(), 3); - for (report_shares, report) in agg_job_init_req.report_shares.iter().zip(reports.iter()) { - assert_eq!(report_shares.report_metadata.id, report.report_metadata.id); + assert_eq!(agg_job_init_req.prep_inits.len(), 3); + for (prep_init, report) in agg_job_init_req.prep_inits.iter().zip(reports.iter()) { + assert_eq!( + prep_init.report_share.report_metadata.id, + report.report_metadata.id + ); } - let (helper_state, agg_job_resp) = t - .handle_agg_job_init_req(&agg_job_init_req) - .await - .unwrap_continue(); - assert_eq!(helper_state.seq.len(), 3); - assert_eq!(agg_job_resp.transitions.len(), 3); - for (sub, report) in agg_job_resp.transitions.iter().zip(reports.iter()) { - assert_eq!(sub.report_id, report.report_metadata.id); + match t.handle_agg_job_init_req(&agg_job_init_req).await { + DapHelperAggregationJobTransition::Continued(helper_state, agg_job_resp) => { + assert_eq!(helper_state.seq.len(), 3); + assert_eq!(agg_job_resp.transitions.len(), 3); + for (sub, report) in agg_job_resp.transitions.iter().zip(reports.iter()) { + assert_eq!(sub.report_id, report.report_metadata.id); + } + } + DapHelperAggregationJobTransition::Finished(Some(agg_share_span), agg_job_resp) => { + assert_eq!(agg_share_span.report_count(), 3); + assert_eq!(agg_job_resp.transitions.len(), 3); + } + _ => panic!("unexpected transition"), } } @@ 
-1588,7 +1755,7 @@ mod test { assert_matches!( t.produce_agg_job_init_req(reports).await, - DapLeaderTransition::Skip + DapLeaderAggregationJobTransition::Finished(None) ); assert_metrics_include!(t.prometheus_registry, { @@ -1607,7 +1774,7 @@ mod test { assert_matches!( t.produce_agg_job_init_req(reports).await, - DapLeaderTransition::Skip + DapLeaderAggregationJobTransition::Finished(None) ); assert_metrics_include!(t.prometheus_registry, { @@ -1626,7 +1793,7 @@ mod test { assert_matches!( t.produce_agg_job_init_req(reports).await, - DapLeaderTransition::Skip + DapLeaderAggregationJobTransition::Finished(None) ); assert_metrics_include!(t.prometheus_registry, { @@ -1646,11 +1813,11 @@ mod test { let (_, agg_job_init_req) = t .produce_agg_job_init_req(reports.clone()) .await - .unwrap_continue(); - let (_, agg_job_resp) = t + .unwrap_continued(); + let agg_job_resp = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_msg(); assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( @@ -1675,11 +1842,11 @@ mod test { let (_, agg_job_init_req) = t .produce_agg_job_init_req(reports.clone()) .await - .unwrap_continue(); - let (_, agg_job_resp) = t + .unwrap_continued(); + let agg_job_resp = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_msg(); assert_eq!(agg_job_resp.transitions.len(), 1); assert_matches!( @@ -1706,24 +1873,30 @@ mod test { draft02_agg_job_id: t.agg_job_id.for_request_payload(), agg_param: Vec::new(), part_batch_sel: PartialBatchSelector::TimeInterval, - report_shares: vec![ - ReportShare { - report_metadata: report0.report_metadata, - public_share: report0.public_share, - encrypted_input_share: report0.encrypted_input_shares[1].clone(), + prep_inits: vec![ + PrepareInit { + report_share: ReportShare { + report_metadata: report0.report_metadata, + public_share: report0.public_share, + encrypted_input_share: report0.encrypted_input_shares[1].clone(), + }, + draft05_payload: 
None, }, - ReportShare { - report_metadata: report1.report_metadata, - public_share: report1.public_share, - encrypted_input_share: report1.encrypted_input_shares[1].clone(), + PrepareInit { + report_share: ReportShare { + report_metadata: report1.report_metadata, + public_share: report1.public_share, + encrypted_input_share: report1.encrypted_input_shares[1].clone(), + }, + draft05_payload: None, }, ], }; - let (_, agg_job_resp) = t + let agg_job_resp = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_msg(); assert_eq!(agg_job_resp.transitions.len(), 2); assert_matches!( @@ -1746,11 +1919,11 @@ mod test { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); - let (_, mut agg_job_resp) = t + t.produce_agg_job_init_req(reports).await.unwrap_continued(); + let mut agg_job_resp = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_msg(); // Helper sends transitions out of order. let tmp = agg_job_resp.transitions[0].clone(); @@ -1769,11 +1942,11 @@ mod test { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); - let (_, mut agg_job_resp) = t + t.produce_agg_job_init_req(reports).await.unwrap_continued(); + let mut agg_job_resp = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_msg(); // Helper sends a transition twice. 
let repeated_transition = agg_job_resp.transitions[0].clone(); @@ -1792,11 +1965,11 @@ mod test { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); - let (_, mut agg_job_resp) = t + t.produce_agg_job_init_req(reports).await.unwrap_continued(); + let mut agg_job_resp = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_msg(); // Helper sent a transition with an unrecognized report ID. agg_job_resp.transitions.push(Transition { @@ -1816,11 +1989,11 @@ mod test { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let reports = t.produce_reports(vec![DapMeasurement::U64(1)]); let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); - let (_, mut agg_job_resp) = t + t.produce_agg_job_init_req(reports).await.unwrap_continued(); + let mut agg_job_resp = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_msg(); // Helper sent a transition with an unrecognized report ID. 
agg_job_resp.transitions[0].var = TransitionVar::Finished; @@ -1842,28 +2015,46 @@ mod test { DapMeasurement::U64(0), DapMeasurement::U64(1), ]); + let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); - let (helper_state, agg_job_resp) = t - .handle_agg_job_init_req(&agg_job_init_req) - .await - .unwrap_continue(); + t.produce_agg_job_init_req(reports).await.unwrap_continued(); - let (leader_uncommitted, agg_job_cont_req) = t - .handle_agg_job_resp(leader_state, agg_job_resp) - .unwrap_uncommitted(); + let (leader_agg_share_span, helper_agg_share_span) = + match t.handle_agg_job_init_req(&agg_job_init_req).await { + DapHelperAggregationJobTransition::Continued(helper_state, agg_job_resp) => { + // draft02 + let (leader_uncommitted, agg_job_cont_req) = t + .handle_agg_job_resp(leader_state, agg_job_resp) + .unwrap_uncommitted(); - let (helper_agg_share_span, agg_job_resp) = - t.handle_agg_job_cont_req(&helper_state, &agg_job_cont_req); - assert_eq!(helper_agg_share_span.report_count(), 5); - assert_eq!(agg_job_resp.transitions.len(), 5); + let (helper_agg_share_span, agg_job_resp) = + t.handle_agg_job_cont_req(&helper_state, &agg_job_cont_req); + assert_eq!(helper_agg_share_span.report_count(), 5); + assert_eq!(agg_job_resp.transitions.len(), 5); - let leader_share_span = t.handle_final_agg_job_resp(leader_uncommitted, agg_job_resp); - assert_eq!(leader_share_span.report_count(), 5); - let num_measurements = leader_share_span.report_count(); + let leader_agg_share_span = + t.handle_final_agg_job_resp(leader_uncommitted, agg_job_resp); + + (leader_agg_share_span, helper_agg_share_span) + } + DapHelperAggregationJobTransition::Finished( + Some(helper_agg_share_span), + agg_job_resp, + ) => { + let leader_agg_share_span = t + .handle_agg_job_resp(leader_state, agg_job_resp) + .unwrap_finished(); + + (leader_agg_share_span, helper_agg_share_span) + } + _ => panic!("unexpected transition"), + }; + + 
assert_eq!(leader_agg_share_span.report_count(), 5); + let num_measurements = leader_agg_share_span.report_count(); let VdafAggregateShare::Field64(leader_agg_share) = - leader_share_span.collapsed().data.unwrap() + leader_agg_share_span.collapsed().data.unwrap() else { panic!("unexpected VdafAggregateShare variant") }; @@ -1884,20 +2075,22 @@ mod test { async_test_versions! { agg_job_cont_req } - async fn agg_job_cont_req_skip_vdaf_prep_error(version: DapVersion) { - let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); + #[tokio::test] + async fn agg_job_cont_req_skip_vdaf_prep_error_draft02() { + let t = + AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, DapVersion::Draft02); let mut reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); reports.insert( 1, - t.produce_invalid_report_vdaf_prep_failure(DapMeasurement::U64(1), version), + t.produce_invalid_report_vdaf_prep_failure(DapMeasurement::U64(1), DapVersion::Draft02), ); let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); + t.produce_agg_job_init_req(reports).await.unwrap_continued(); let (helper_state, agg_job_resp) = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_continued(); let (_, agg_job_cont_req) = t .handle_agg_job_resp(leader_state, agg_job_resp) @@ -1910,11 +2103,17 @@ mod test { assert_eq!(2, agg_job_resp.transitions.len()); assert_eq!( agg_job_resp.transitions[0].report_id, - agg_job_init_req.report_shares[0].report_metadata.id + agg_job_init_req.prep_inits[0] + .report_share + .report_metadata + .id ); assert_eq!( agg_job_resp.transitions[1].report_id, - agg_job_init_req.report_shares[2].report_metadata.id + agg_job_init_req.prep_inits[2] + .report_share + .report_metadata + .id ); assert_metrics_include!(t.prometheus_registry, { @@ -1922,18 +2121,55 @@ mod test { }); } - async_test_versions! 
{ agg_job_cont_req_skip_vdaf_prep_error } + #[tokio::test] + async fn agg_job_init_req_skip_vdaf_prep_error_draft07() { + let t = + AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, DapVersion::Draft07); + let mut reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); + reports.insert( + 1, + t.produce_invalid_report_vdaf_prep_failure(DapMeasurement::U64(1), DapVersion::Draft07), + ); - async fn agg_cont_abort_unrecognized_report_id(version: DapVersion) { + let (leader_state, agg_job_init_req) = + t.produce_agg_job_init_req(reports).await.unwrap_continued(); + let (helper_agg_share_span, agg_job_resp) = t + .handle_agg_job_init_req(&agg_job_init_req) + .await + .unwrap_finished(); + + assert_eq!(2, helper_agg_share_span.report_count()); + assert_eq!(3, agg_job_resp.transitions.len()); + for i in 0..3 { + assert_eq!( + agg_job_resp.transitions[i].report_id, + agg_job_init_req.prep_inits[i] + .report_share + .report_metadata + .id + ); + } + + let _leader_agg_share_span = t.handle_agg_job_resp(leader_state, agg_job_resp); + + assert_metrics_include!(t.prometheus_registry, { + r#"test_leader_report_counter{host="leader.com",status="rejected_vdaf_prep_error"}"#: 1, + r#"test_helper_report_counter{host="helper.org",status="rejected_vdaf_prep_error"}"#: 1, + }); + } + + #[tokio::test] + async fn agg_cont_abort_unrecognized_report_id_draft02() { let mut rng = thread_rng(); - let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); + let t = + AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, DapVersion::Draft02); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); + t.produce_agg_job_init_req(reports).await.unwrap_continued(); let (helper_state, agg_job_resp) = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_continued(); 
let (_, mut agg_job_cont_req) = t .handle_agg_job_resp(leader_state, agg_job_resp) @@ -1953,17 +2189,17 @@ mod test { ); } - async_test_versions! { agg_cont_abort_unrecognized_report_id } - - async fn agg_job_cont_req_abort_transition_out_of_order(version: DapVersion) { - let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); + #[tokio::test] + async fn agg_job_cont_req_abort_transition_out_of_order_draft02() { + let t = + AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, DapVersion::Draft02); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); + t.produce_agg_job_init_req(reports).await.unwrap_continued(); let (helper_state, agg_job_resp) = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_continued(); let (_, mut agg_job_cont_req) = t .handle_agg_job_resp(leader_state, agg_job_resp) @@ -1979,17 +2215,17 @@ mod test { ); } - async_test_versions! { agg_job_cont_req_abort_transition_out_of_order } - - async fn agg_job_cont_req_abort_report_id_repeated(version: DapVersion) { - let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); + #[tokio::test] + async fn agg_job_cont_req_abort_report_id_repeated_draft02() { + let t = + AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, DapVersion::Draft02); let reports = t.produce_reports(vec![DapMeasurement::U64(1), DapMeasurement::U64(1)]); let (leader_state, agg_job_init_req) = - t.produce_agg_job_init_req(reports).await.unwrap_continue(); + t.produce_agg_job_init_req(reports).await.unwrap_continued(); let (helper_state, agg_job_resp) = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_continued(); let (_, mut agg_job_cont_req) = t .handle_agg_job_resp(leader_state, agg_job_resp) @@ -2004,8 +2240,6 @@ mod test { ); } - async_test_versions! 
{ agg_job_cont_req_abort_report_id_repeated } - async fn encrypted_agg_share(version: DapVersion) { let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); let leader_agg_share = DapAggregateShare { @@ -2050,8 +2284,10 @@ mod test { async_test_versions! { encrypted_agg_share } - async fn helper_state_serialization(version: DapVersion) { - let t = AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, version); + #[tokio::test] + async fn helper_state_serialization_draft02() { + let t = + AggregationJobTest::new(TEST_VDAF, HpkeKemId::X25519HkdfSha256, DapVersion::Draft02); let reports = t.produce_reports(vec![ DapMeasurement::U64(1), DapMeasurement::U64(1), @@ -2059,20 +2295,18 @@ mod test { DapMeasurement::U64(0), DapMeasurement::U64(1), ]); - let (_, agg_job_init_req) = t.produce_agg_job_init_req(reports).await.unwrap_continue(); + let (_, agg_job_init_req) = t.produce_agg_job_init_req(reports).await.unwrap_continued(); let (want, _) = t .handle_agg_job_init_req(&agg_job_init_req) .await - .unwrap_continue(); + .unwrap_continued(); - let got = DapHelperState::get_decoded(TEST_VDAF, &want.get_encoded()).unwrap(); - assert_eq!(got, want); + let got = DapAggregationJobState::get_decoded(TEST_VDAF, &want.get_encoded()).unwrap(); + assert_eq!(got.get_encoded(), want.get_encoded()); - assert!(DapHelperState::get_decoded(TEST_VDAF, b"invalid helper state").is_err()) + assert!(DapAggregationJobState::get_decoded(TEST_VDAF, b"invalid helper state").is_err()) } - async_test_versions! { helper_state_serialization } - impl AggregationJobTest { // Tweak the Helper's share so that decoding succeeds but preparation fails. 
fn produce_invalid_report_vdaf_prep_failure( diff --git a/daphne_worker/src/roles/helper.rs b/daphne_worker/src/roles/helper.rs index 60113efc4..fe9a411de 100644 --- a/daphne_worker/src/roles/helper.rs +++ b/daphne_worker/src/roles/helper.rs @@ -16,8 +16,8 @@ use crate::{ }; use async_trait::async_trait; use daphne::{ - error::DapAbort, fatal_error, messages::TaskId, roles::DapHelper, DapError, DapHelperState, - MetaAggregationJobId, + error::DapAbort, fatal_error, messages::TaskId, roles::DapHelper, DapAggregationJobState, + DapError, MetaAggregationJobId, }; use prio::codec::Encode; @@ -27,7 +27,7 @@ impl<'srv> DapHelper for DaphneWorker<'srv> { &self, task_id: &TaskId, agg_job_id: &MetaAggregationJobId, - helper_state: &DapHelperState, + helper_state: &DapAggregationJobState, ) -> std::result::Result { let task_config = self.try_get_task_config(task_id).await?; let helper_state_hex = hex::encode(helper_state.get_encoded()); @@ -48,7 +48,7 @@ impl<'srv> DapHelper for DaphneWorker<'srv> { &self, task_id: &TaskId, agg_job_id: &MetaAggregationJobId, - ) -> std::result::Result, DapError> { + ) -> std::result::Result, DapError> { let task_config = self.try_get_task_config(task_id).await?; // TODO(cjpatton) Figure out if retry is safe, since the request is not actually // idempotent. (It removes the helper's state from storage if it exists.) @@ -67,7 +67,8 @@ impl<'srv> DapHelper for DaphneWorker<'srv> { Some(helper_state_hex) => { let data = hex::decode(helper_state_hex) .map_err(|e| DapAbort::from_hex_error(e, task_id.clone()))?; - let helper_state = DapHelperState::get_decoded(&task_config.as_ref().vdaf, &data)?; + let helper_state = + DapAggregationJobState::get_decoded(&task_config.as_ref().vdaf, &data)?; Ok(Some(helper_state)) } None => Ok(None),