diff --git a/daphne/src/roles/helper.rs b/daphne/src/roles/helper.rs
index 4c122fce2..32b3b608a 100644
--- a/daphne/src/roles/helper.rs
+++ b/daphne/src/roles/helper.rs
@@ -193,7 +193,7 @@ pub trait DapHelper: DapAggregator {
             return Err(DapAbort::version_mismatch(req.version, task_config.version));
         }

-        let mut agg_job_cont_req =
+        let agg_job_cont_req =
             AggregationJobContinueReq::get_decoded_with_param(&req.version, &req.payload)
                 .map_err(|e| DapAbort::from_codec_error(e, task_id.clone()))?;
@@ -218,9 +218,8 @@ pub trait DapHelper: DapAggregator {
         // against them, as such, even though retrying is possibly very expensive, it probably
         // won't happen often enough that it matters.
         const RETRY_COUNT: u32 = 3;
-        // Slots for the output transitions
-        let mut output_transitions: Vec<(ReportId, Option)> = agg_job_cont_req
-            .transitions
+        let mut input_transitions = agg_job_cont_req.transitions;
+        let mut output_transitions: Vec<(ReportId, Option)> = input_transitions
             .iter()
             .map(|t| &t.report_id)
             .cloned()
@@ -231,8 +230,9 @@ pub trait DapHelper: DapAggregator {
                 task_id,
                 task_config,
                 &state,
+                agg_job_cont_req.round,
                 &agg_job_id,
-                &agg_job_cont_req,
+                &input_transitions,
             )?;

             let put_shares_result = self
@@ -254,8 +254,7 @@ pub trait DapHelper: DapAggregator {
                 // we got 0 replays
                 (Ok(replays), reports) if replays.is_empty() => {
                     // remove successfull reports in this batch as they are finished
-                    agg_job_cont_req
-                        .transitions
+                    input_transitions
                         .retain(|t| reports.iter().all(|(id, _)| *id != t.report_id));
                     // set those reports as finished in the output_transitions
                     output_transitions
@@ -266,9 +265,7 @@ pub trait DapHelper: DapAggregator {
                 // this bucket had replays
                 (Ok(replays), _) => {
                     // remove the replays from the input transitions
-                    agg_job_cont_req
-                        .transitions
-                        .retain(|t| !replays.contains(&t.report_id));
+                    input_transitions.retain(|t| !replays.contains(&t.report_id));
                     // mark the replayed reports as replays in the output_transitions
                     output_transitions
                         .iter_mut()
diff --git a/daphne/src/testing.rs b/daphne/src/testing.rs
index 7655a636e..7822c6c92 100644
--- a/daphne/src/testing.rs
+++ b/daphne/src/testing.rs
@@ -294,8 +294,9 @@ impl AggregationJobTest {
                 &self.task_id,
                 &self.task_config,
                 helper_state,
+                agg_job_cont_req.round,
                 &self.agg_job_id,
-                agg_job_cont_req,
+                &agg_job_cont_req.transitions,
             )
             .unwrap()
     }
@@ -312,8 +313,9 @@ impl AggregationJobTest {
                 &self.task_id,
                 &self.task_config,
                 &helper_state,
+                agg_job_cont_req.round,
                 &self.agg_job_id,
-                agg_job_cont_req,
+                &agg_job_cont_req.transitions,
             )
             .expect_err("handle_agg_job_cont_req() succeeded; expected failure")
     }
diff --git a/daphne/src/vdaf/mod.rs b/daphne/src/vdaf/mod.rs
index 10b5cbc48..5024e33d3 100644
--- a/daphne/src/vdaf/mod.rs
+++ b/daphne/src/vdaf/mod.rs
@@ -1000,10 +1000,11 @@ impl VdafConfig {
         task_id: &TaskId,
         task_config: &DapTaskConfig,
         state: &DapHelperState,
+        round: Option,
         agg_job_id: &MetaAggregationJobId<'_>,
-        agg_job_cont_req: &AggregationJobContinueReq,
+        input_transitions: &[Transition],
     ) -> Result<(DapAggregateSpan, AggregationJobResp), DapAbort> {
-        match agg_job_cont_req.round {
+        match round {
             Some(1) | None => {}
             Some(0) => {
                 return Err(DapAbort::UnrecognizedMessage {
@@ -1030,7 +1031,7 @@ impl VdafConfig {
         let mut transitions = Vec::with_capacity(state.seq.len());
         let mut agg_share_span = DapAggregateSpan::default();
         let mut helper_iter = state.seq.iter();
-        for leader in &agg_job_cont_req.transitions {
+        for leader in input_transitions {
             // If the report ID is not recognized, then respond with a
             // transition failure.
             //
             // TODO spec: Having to enforce this is awkward because, in order to disambiguate the