Skip to content

Commit 66dea79

Browse files
authored
Message & request changes for async aggregation/collection. (#3522)
Implementation of asynchronous aggregation is forthcoming. With this change, we use the Pending/Finished statuses for aggregation/collection as-specified, but both the Leader & Helper support only synchronous aggregation.
1 parent a2c0398 commit 66dea79

19 files changed

+852
-569
lines changed

aggregator/src/aggregator.rs

+53-47
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ use janus_messages::{
6868
taskprov::TaskConfig,
6969
AggregateShare, AggregateShareAad, AggregateShareReq, AggregationJobContinueReq,
7070
AggregationJobId, AggregationJobInitializeReq, AggregationJobResp, AggregationJobStep,
71-
BatchSelector, Collection, CollectionJobId, CollectionReq, Duration, ExtensionType, HpkeConfig,
72-
HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
71+
BatchSelector, CollectionJobId, CollectionJobReq, CollectionJobResp, Duration, ExtensionType,
72+
HpkeConfig, HpkeConfigList, InputShareAad, Interval, PartialBatchSelector, PlaintextInputShare,
7373
PrepareResp, PrepareStepResult, Report, ReportError, ReportIdChecksum, ReportShare, Role,
7474
TaskId,
7575
};
@@ -529,14 +529,14 @@ impl<C: Clock> Aggregator<C> {
529529
}
530530

531531
/// Handle a collection job creation request. Only supported by the leader. `req_bytes` is an
532-
/// encoded [`CollectionReq`].
532+
/// encoded [`CollectionJobReq`]. Returns an encoded [`CollectionJobResp`] on success.
533533
async fn handle_create_collection_job(
534534
&self,
535535
task_id: &TaskId,
536536
collection_job_id: &CollectionJobId,
537537
req_bytes: &[u8],
538538
auth_token: Option<AuthenticationToken>,
539-
) -> Result<(), Error> {
539+
) -> Result<Vec<u8>, Error> {
540540
let task_aggregator = self
541541
.task_aggregators
542542
.get(task_id)
@@ -566,7 +566,7 @@ impl<C: Clock> Aggregator<C> {
566566
task_id: &TaskId,
567567
collection_job_id: &CollectionJobId,
568568
auth_token: Option<AuthenticationToken>,
569-
) -> Result<Option<Vec<u8>>, Error> {
569+
) -> Result<Vec<u8>, Error> {
570570
let task_aggregator = self
571571
.task_aggregators
572572
.get(task_id)
@@ -1034,7 +1034,7 @@ impl<C: Clock> TaskAggregator<C> {
10341034
datastore: &Datastore<C>,
10351035
collection_job_id: &CollectionJobId,
10361036
req_bytes: &[u8],
1037-
) -> Result<(), Error> {
1037+
) -> Result<Vec<u8>, Error> {
10381038
self.vdaf_ops
10391039
.handle_create_collection_job(
10401040
datastore,
@@ -1049,7 +1049,7 @@ impl<C: Clock> TaskAggregator<C> {
10491049
&self,
10501050
datastore: &Datastore<C>,
10511051
collection_job_id: &CollectionJobId,
1052-
) -> Result<Option<Vec<u8>>, Error> {
1052+
) -> Result<Vec<u8>, Error> {
10531053
self.vdaf_ops
10541054
.handle_get_collection_job(datastore, Arc::clone(&self.task), collection_job_id)
10551055
.await
@@ -1834,19 +1834,20 @@ impl VdafOps {
18341834
}
18351835

18361836
// This is a repeated request. Send the same response we computed last time.
1837-
return Ok(Some(AggregationJobResp::new(
1838-
tx.get_report_aggregations_for_aggregation_job(
1839-
vdaf,
1840-
&Role::Helper,
1841-
task_id,
1842-
aggregation_job_id,
1843-
)
1844-
.await?
1845-
.iter()
1846-
.filter_map(ReportAggregation::last_prep_resp)
1847-
.cloned()
1848-
.collect(),
1849-
)));
1837+
return Ok(Some(AggregationJobResp::Finished {
1838+
prepare_resps: tx
1839+
.get_report_aggregations_for_aggregation_job(
1840+
vdaf,
1841+
&Role::Helper,
1842+
task_id,
1843+
aggregation_job_id,
1844+
)
1845+
.await?
1846+
.iter()
1847+
.filter_map(ReportAggregation::last_prep_resp)
1848+
.cloned()
1849+
.collect(),
1850+
}));
18501851
}
18511852

18521853
Ok(None)
@@ -2358,11 +2359,11 @@ impl VdafOps {
23582359
let (mut prep_resps_by_agg_job, counters) =
23592360
aggregation_job_writer.write(tx, vdaf).await?;
23602361
Ok((
2361-
AggregationJobResp::new(
2362-
prep_resps_by_agg_job
2362+
AggregationJobResp::Finished {
2363+
prepare_resps: prep_resps_by_agg_job
23632364
.remove(aggregation_job.id())
23642365
.unwrap_or_default(),
2365-
),
2366+
},
23662367
counters,
23672368
))
23682369
})
@@ -2473,13 +2474,13 @@ impl VdafOps {
24732474
}
24742475
}
24752476
return Ok((
2476-
AggregationJobResp::new(
2477-
report_aggregations
2477+
AggregationJobResp::Finished {
2478+
prepare_resps: report_aggregations
24782479
.iter()
24792480
.filter_map(ReportAggregation::last_prep_resp)
24802481
.cloned()
24812482
.collect(),
2482-
),
2483+
},
24832484
TaskAggregationCounter::default(),
24842485
));
24852486
} else if aggregation_job.step().increment() != req.step() {
@@ -2574,7 +2575,7 @@ impl VdafOps {
25742575
task: Arc<AggregatorTask>,
25752576
collection_job_id: &CollectionJobId,
25762577
collection_req_bytes: &[u8],
2577-
) -> Result<(), Error> {
2578+
) -> Result<Vec<u8>, Error> {
25782579
match task.batch_mode() {
25792580
task::BatchMode::TimeInterval => {
25802581
vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => {
@@ -2612,19 +2613,19 @@ impl VdafOps {
26122613
vdaf: Arc<A>,
26132614
collection_job_id: &CollectionJobId,
26142615
req_bytes: &[u8],
2615-
) -> Result<(), Error>
2616+
) -> Result<Vec<u8>, Error>
26162617
where
26172618
A::AggregationParam: 'static + Send + Sync + PartialEq + Eq + Hash,
26182619
A::AggregateShare: Send + Sync,
26192620
{
26202621
let req =
2621-
Arc::new(CollectionReq::<B>::get_decoded(req_bytes).map_err(Error::MessageDecode)?);
2622+
Arc::new(CollectionJobReq::<B>::get_decoded(req_bytes).map_err(Error::MessageDecode)?);
26222623
let aggregation_param = Arc::new(
26232624
A::AggregationParam::get_decoded(req.aggregation_parameter())
26242625
.map_err(Error::MessageDecode)?,
26252626
);
26262627

2627-
Ok(datastore
2628+
datastore
26282629
.run_tx("collect", move |tx| {
26292630
let (task, vdaf, collection_job_id, req, aggregation_param) = (
26302631
Arc::clone(&task),
@@ -2715,7 +2716,11 @@ impl VdafOps {
27152716
Ok(())
27162717
})
27172718
})
2718-
.await?)
2719+
.await?;
2720+
2721+
CollectionJobResp::<B>::Processing
2722+
.get_encoded()
2723+
.map_err(Error::MessageEncode)
27192724
}
27202725

27212726
/// Handle GET requests to the leader's `tasks/{task-id}/collection_jobs/{collection-job-id}`
@@ -2727,7 +2732,7 @@ impl VdafOps {
27272732
datastore: &Datastore<C>,
27282733
task: Arc<AggregatorTask>,
27292734
collection_job_id: &CollectionJobId,
2730-
) -> Result<Option<Vec<u8>>, Error> {
2735+
) -> Result<Vec<u8>, Error> {
27312736
match task.batch_mode() {
27322737
task::BatchMode::TimeInterval => {
27332738
vdaf_ops_dispatch!(self, (vdaf, _, VdafType, VERIFY_KEY_LENGTH) => {
@@ -2765,7 +2770,7 @@ impl VdafOps {
27652770
task: Arc<AggregatorTask>,
27662771
vdaf: Arc<A>,
27672772
collection_job_id: &CollectionJobId,
2768-
) -> Result<Option<Vec<u8>>, Error>
2773+
) -> Result<Vec<u8>, Error>
27692774
where
27702775
A::AggregationParam: Send + Sync,
27712776
A::AggregateShare: Send + Sync,
@@ -2790,7 +2795,7 @@ impl VdafOps {
27902795
match collection_job.state() {
27912796
CollectionJobState::Start => {
27922797
debug!(%collection_job_id, task_id = %task.id(), "collection job has not run yet");
2793-
Ok(None)
2798+
Ok(CollectionJobResp::<B>::Processing)
27942799
}
27952800

27962801
CollectionJobState::Finished {
@@ -2839,19 +2844,15 @@ impl VdafOps {
28392844
.map_err(Error::MessageEncode)?,
28402845
)?;
28412846

2842-
Ok(Some(
2843-
Collection::<B>::new(
2844-
PartialBatchSelector::new(
2845-
B::partial_batch_identifier(collection_job.batch_identifier()).clone(),
2846-
),
2847-
*report_count,
2848-
*client_timestamp_interval,
2849-
encrypted_leader_aggregate_share,
2850-
encrypted_helper_aggregate_share.clone(),
2851-
)
2852-
.get_encoded()
2853-
.map_err(Error::MessageEncode)?,
2854-
))
2847+
Ok(CollectionJobResp::<B>::Finished {
2848+
partial_batch_selector: PartialBatchSelector::new(
2849+
B::partial_batch_identifier(collection_job.batch_identifier()).clone(),
2850+
),
2851+
report_count: *report_count,
2852+
interval: *client_timestamp_interval,
2853+
leader_encrypted_agg_share: encrypted_leader_aggregate_share,
2854+
helper_encrypted_agg_share: encrypted_helper_aggregate_share.clone(),
2855+
})
28552856
}
28562857

28572858
CollectionJobState::Abandoned => Err(Error::AbandonedCollectionJob(
@@ -2863,6 +2864,11 @@ impl VdafOps {
28632864
Err(Error::DeletedCollectionJob(*task.id(), *collection_job_id))
28642865
}
28652866
}
2867+
.and_then(|collection_job_resp| {
2868+
collection_job_resp
2869+
.get_encoded()
2870+
.map_err(Error::MessageEncode)
2871+
})
28662872
}
28672873

28682874
#[tracing::instrument(skip(self, datastore, task), fields(task_id = ?task.id()), err(level = Level::DEBUG))]

aggregator/src/aggregator/aggregate_init_tests.rs

+27-17
Original file line numberDiff line numberDiff line change
@@ -212,19 +212,23 @@ async fn setup_aggregate_init_test_for_vdaf<
212212
&test_case.handler,
213213
)
214214
.await;
215-
assert_eq!(response.status(), Some(Status::Ok));
215+
assert_eq!(response.status(), Some(Status::Created));
216216

217-
let aggregation_job_init_resp: AggregationJobResp = decode_response_body(&mut response).await;
217+
let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await;
218+
let prepare_resps = assert_matches!(
219+
&aggregation_job_resp,
220+
AggregationJobResp::Finished { prepare_resps } => prepare_resps
221+
);
218222
assert_eq!(
219-
aggregation_job_init_resp.prepare_resps().len(),
223+
prepare_resps.len(),
220224
test_case.aggregation_job_init_req.prepare_inits().len(),
221225
);
222226
assert_matches!(
223-
aggregation_job_init_resp.prepare_resps()[0].result(),
227+
prepare_resps[0].result(),
224228
&PrepareStepResult::Continue { .. }
225229
);
226230

227-
test_case.aggregation_job_init_resp = Some(aggregation_job_init_resp);
231+
test_case.aggregation_job_init_resp = Some(aggregation_job_resp);
228232
test_case
229233
}
230234

@@ -345,7 +349,7 @@ async fn aggregation_job_init_authorization_dap_auth_token() {
345349
.run_async(&test_case.handler)
346350
.await;
347351

348-
assert_eq!(response.status(), Some(Status::Ok));
352+
assert_eq!(response.status(), Some(Status::Created));
349353
}
350354

351355
#[rstest::rstest]
@@ -420,12 +424,14 @@ async fn aggregation_job_init_unexpected_taskprov_extension() {
420424
&test_case.handler,
421425
)
422426
.await;
423-
assert_eq!(response.status(), Some(Status::Ok));
424-
425-
let want_aggregation_job_resp = AggregationJobResp::new(Vec::from([PrepareResp::new(
426-
report_id,
427-
PrepareStepResult::Reject(ReportError::InvalidMessage),
428-
)]));
427+
assert_eq!(response.status(), Some(Status::Created));
428+
429+
let want_aggregation_job_resp = AggregationJobResp::Finished {
430+
prepare_resps: Vec::from([PrepareResp::new(
431+
report_id,
432+
PrepareStepResult::Reject(ReportError::InvalidMessage),
433+
)]),
434+
};
429435
let got_aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await;
430436
assert_eq!(want_aggregation_job_resp, got_aggregation_job_resp);
431437
}
@@ -589,19 +595,23 @@ async fn aggregation_job_intolerable_clock_skew() {
589595
&test_case.handler,
590596
)
591597
.await;
592-
assert_eq!(response.status(), Some(Status::Ok));
598+
assert_eq!(response.status(), Some(Status::Created));
593599

594-
let aggregation_job_init_resp: AggregationJobResp = decode_response_body(&mut response).await;
600+
let aggregation_job_resp: AggregationJobResp = decode_response_body(&mut response).await;
601+
let prepare_resps = assert_matches!(
602+
aggregation_job_resp,
603+
AggregationJobResp::Finished { prepare_resps } => prepare_resps
604+
);
595605
assert_eq!(
596-
aggregation_job_init_resp.prepare_resps().len(),
606+
prepare_resps.len(),
597607
test_case.aggregation_job_init_req.prepare_inits().len(),
598608
);
599609
assert_matches!(
600-
aggregation_job_init_resp.prepare_resps()[0].result(),
610+
prepare_resps[0].result(),
601611
&PrepareStepResult::Continue { .. }
602612
);
603613
assert_matches!(
604-
aggregation_job_init_resp.prepare_resps()[1].result(),
614+
prepare_resps[1].result(),
605615
&PrepareStepResult::Reject(ReportError::ReportTooEarly)
606616
);
607617
}

aggregator/src/aggregator/aggregation_job_continue.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -286,11 +286,11 @@ impl VdafOps {
286286
aggregation_job_writer.put(aggregation_job, report_aggregations_to_write)?;
287287
let (mut prep_resps_by_agg_job, counters) = aggregation_job_writer.write(tx, vdaf).await?;
288288
Ok((
289-
AggregationJobResp::new(
290-
prep_resps_by_agg_job
289+
AggregationJobResp::Finished {
290+
prepare_resps: prep_resps_by_agg_job
291291
.remove(&aggregation_job_id)
292292
.unwrap_or_default(),
293-
),
293+
},
294294
counters,
295295
))
296296
}
@@ -334,7 +334,7 @@ pub mod test_util {
334334
) -> AggregationJobResp {
335335
let mut test_conn = post_aggregation_job(task, aggregation_job_id, request, handler).await;
336336

337-
assert_eq!(test_conn.status(), Some(Status::Ok));
337+
assert_eq!(test_conn.status(), Some(Status::Accepted));
338338
assert_headers!(&test_conn, "content-type" => (AggregationJobResp::MEDIA_TYPE));
339339
decode_response_body::<AggregationJobResp>(&mut test_conn).await
340340
}
@@ -582,14 +582,14 @@ mod tests {
582582
// Validate response.
583583
assert_eq!(
584584
first_continue_response,
585-
AggregationJobResp::new(
586-
test_case
585+
AggregationJobResp::Finished {
586+
prepare_resps: test_case
587587
.first_continue_request
588588
.prepare_steps()
589589
.iter()
590590
.map(|step| PrepareResp::new(*step.report_id(), PrepareStepResult::Finished))
591591
.collect()
592-
)
592+
}
593593
);
594594

595595
test_case.first_continue_response = Some(first_continue_response);

0 commit comments

Comments
 (0)