Skip to content

Commit 85c9c24

Browse files
committed
Remove MARK_AGGREGATED and INITIALIZE from ReportsProcessed
1 parent 85f905d commit 85c9c24

File tree

6 files changed

+208
-272
lines changed

6 files changed

+208
-272
lines changed

daphne/src/lib.rs

+48-19
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,32 @@ impl<T> Extend<(DapBatchBucket, (T, Vec<(ReportId, Time)>))> for DapAggregateSpa
366366
}
367367
}
368368

369+
/// Builds a `DapAggregateSpan<()>` from an iterator of `(bucket, (report ID, time))` pairs.
///
/// Construction starts from `Self::default()` and then delegates to the `Extend` implementation
/// for the same item type, so both traits group items identically.
impl FromIterator<(DapBatchBucket, (ReportId, Time))> for DapAggregateSpan<()> {
370+
fn from_iter<I>(iter: I) -> Self
371+
where
372+
I: IntoIterator<Item = (DapBatchBucket, (ReportId, Time))>,
373+
{
374+
let mut this = Self::default();
375+
// All grouping logic lives in `Extend::extend`.
this.extend(iter);
376+
this
377+
}
378+
}
379+
380+
/// Adds `(bucket, (report ID, time))` pairs to a `DapAggregateSpan<()>`.
///
/// Each `(report ID, time)` pair is appended to the list stored under its bucket. A bucket seen
/// for the first time gets an entry holding the unit value `()` and an empty list.
impl Extend<(DapBatchBucket, (ReportId, Time))> for DapAggregateSpan<()> {
381+
fn extend<I>(&mut self, iter: I)
382+
where
383+
I: IntoIterator<Item = (DapBatchBucket, (ReportId, Time))>,
384+
{
385+
for (k, v) in iter {
386+
self.span
387+
.entry(k)
388+
// First report for this bucket: create the entry with `()` and an empty list.
.or_insert_with(|| ((), Vec::new()))
389+
// `.1` is the list of `(ReportId, Time)` pairs for the bucket.
.1
390+
.push(v);
391+
}
392+
}
393+
}
394+
369395
/// Per-task DAP parameters.
370396
#[derive(Clone, Deserialize, Serialize)]
371397
pub struct DapTaskConfig {
@@ -484,32 +510,35 @@ impl DapTaskConfig {
484510
&self,
485511
part_batch_sel: &'sel PartialBatchSelector,
486512
consumed_reports: impl Iterator<Item = &'rep EarlyReportStateConsumed<'rep>>,
487-
) -> Result<HashMap<DapBatchBucket, Vec<&'rep EarlyReportStateConsumed<'rep>>>, DapError> {
513+
) -> Result<DapAggregateSpan<()>, DapError> {
488514
if !self.query.is_valid_part_batch_sel(part_batch_sel) {
489515
return Err(fatal_error!(
490516
err = "partial batch selector not compatible with task",
491517
));
492518
}
519+
Ok(consumed_reports
520+
.filter(|consumed_report| consumed_report.is_ready())
521+
.map(|consumed_report| {
522+
let bucket = self.bucket_for(part_batch_sel, consumed_report);
523+
let metadata = consumed_report.metadata();
524+
(bucket, (metadata.id.clone(), metadata.time))
525+
})
526+
.collect())
527+
}
493528

494-
let mut span: HashMap<_, Vec<_>> = HashMap::new();
495-
for consumed_report in consumed_reports.filter(|consumed_report| consumed_report.is_ready())
496-
{
497-
let bucket = match part_batch_sel {
498-
PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
499-
batch_window: self.quantized_time_lower_bound(consumed_report.metadata().time),
500-
},
501-
PartialBatchSelector::FixedSizeByBatchId { batch_id } => {
502-
DapBatchBucket::FixedSize {
503-
batch_id: batch_id.clone(),
504-
}
505-
}
506-
};
507-
508-
let consumed_reports_per_bucket = span.entry(bucket).or_default();
509-
consumed_reports_per_bucket.push(consumed_report);
529+
/// Return the batch bucket that `consumed_report` maps to under `part_batch_sel`.
///
/// * `PartialBatchSelector::TimeInterval`: the bucket's `batch_window` is
///   `self.quantized_time_lower_bound` applied to the report's metadata time.
/// * `PartialBatchSelector::FixedSizeByBatchId`: the bucket carries a clone of the selector's
///   batch ID; the report itself is not consulted.
pub fn bucket_for(
530+
&self,
531+
part_batch_sel: &PartialBatchSelector,
532+
consumed_report: &EarlyReportStateConsumed<'_>,
533+
) -> DapBatchBucket {
534+
match part_batch_sel {
535+
PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
536+
batch_window: self.quantized_time_lower_bound(consumed_report.metadata().time),
537+
},
538+
PartialBatchSelector::FixedSizeByBatchId { batch_id } => DapBatchBucket::FixedSize {
539+
batch_id: batch_id.clone(),
540+
},
510541
}
511-
512-
Ok(span)
513542
}
514543

515544
/// Check if the batch size is too small. Returns an error if the report count is too large.

daphne/src/roles/mod.rs

+30-30
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ mod leader;
99

1010
use crate::{
1111
constants::DapMediaType,
12-
messages::{BatchSelector, ReportMetadata, TaskId, Time, TransitionFailure},
12+
messages::{BatchSelector, ReportMetadata, TaskId, Time},
1313
taskprov::{self, TaskprovVersion},
1414
DapAbort, DapError, DapQueryConfig, DapRequest, DapTaskConfig,
1515
};
@@ -103,34 +103,6 @@ async fn check_batch<S>(
103103
Ok(())
104104
}
105105

106-
/// Check for transition failures due to:
107-
///
108-
/// * the report having already been processed
109-
/// * the report having already been collected
110-
/// * the report not being within time bounds
111-
///
112-
/// Returns `Some(TransitionFailure)` if there is a problem,
113-
/// or `None` if no transition failure occurred.
114-
pub fn early_metadata_check(
115-
metadata: &ReportMetadata,
116-
processed: bool,
117-
collected: bool,
118-
min_time: u64,
119-
max_time: u64,
120-
) -> Option<TransitionFailure> {
121-
if processed {
122-
Some(TransitionFailure::ReportReplayed)
123-
} else if collected {
124-
Some(TransitionFailure::BatchCollected)
125-
} else if metadata.time < min_time {
126-
Some(TransitionFailure::ReportDropped)
127-
} else if metadata.time > max_time {
128-
Some(TransitionFailure::ReportTooEarly)
129-
} else {
130-
None
131-
}
132-
}
133-
134106
fn check_request_content_type<S>(
135107
req: &DapRequest<S>,
136108
expected: DapMediaType,
@@ -195,7 +167,7 @@ async fn resolve_taskprov<S>(
195167

196168
#[cfg(test)]
197169
mod test {
198-
use super::{early_metadata_check, DapAggregator, DapAuthorizedSender, DapHelper, DapLeader};
170+
use super::{DapAggregator, DapAuthorizedSender, DapHelper, DapLeader};
199171
use crate::{
200172
assert_metrics_include, async_test_version, async_test_versions,
201173
auth::BearerToken,
@@ -234,6 +206,34 @@ mod test {
234206
}};
235207
}
236208

209+
/// Check for transition failures due to:
210+
///
211+
/// * the report having already been processed
212+
/// * the report having already been collected
213+
/// * the report not being within time bounds
214+
///
215+
/// Returns `Some(TransitionFailure)` if there is a problem,
216+
/// or `None` if no transition failure occurred.
217+
pub fn early_metadata_check(
218+
metadata: &ReportMetadata,
219+
processed: bool,
220+
collected: bool,
221+
min_time: u64,
222+
max_time: u64,
223+
) -> Option<TransitionFailure> {
224+
// The checks are ordered and only the first matching failure is reported: replay, then
// collected batch, then the lower time bound, then the upper time bound.
if processed {
225+
Some(TransitionFailure::ReportReplayed)
226+
} else if collected {
227+
Some(TransitionFailure::BatchCollected)
228+
// Both time comparisons are strict, so a report timestamped exactly `min_time` or
// `max_time` is accepted.
} else if metadata.time < min_time {
229+
Some(TransitionFailure::ReportDropped)
230+
} else if metadata.time > max_time {
231+
Some(TransitionFailure::ReportTooEarly)
232+
} else {
233+
None
234+
}
235+
}
236+
237237
pub(super) struct TestData {
238238
pub now: Time,
239239
global_config: DapGlobalConfig,

daphne/src/testing.rs

+9-13
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::{
1313
AggregationJobContinueReq, AggregationJobId, AggregationJobInitReq, AggregationJobResp,
1414
BatchId, BatchSelector, Collection, CollectionJobId, CollectionReq,
1515
Draft02AggregationJobId, HpkeCiphertext, Interval, PartialBatchSelector, Report, ReportId,
16-
ReportMetadata, TaskId, Time, TransitionFailure,
16+
TaskId, Time, TransitionFailure,
1717
},
1818
metrics::DaphneMetrics,
1919
roles::{DapAggregator, DapAuthorizedSender, DapHelper, DapLeader, DapReportInitializer},
@@ -697,7 +697,7 @@ impl MockAggregator {
697697
&self,
698698
task_id: &TaskId,
699699
bucket: &DapBatchBucket,
700-
metadata: &ReportMetadata,
700+
id: &ReportId,
701701
) -> Option<TransitionFailure> {
702702
// Check AggStateStore to see whether the report is part of a batch that has already
703703
// been collected.
@@ -713,7 +713,7 @@ impl MockAggregator {
713713
.lock()
714714
.expect("report_store: failed to lock");
715715
let report_store = guard.entry(task_id.clone()).or_default();
716-
if report_store.processed.contains(&metadata.id) {
716+
if report_store.processed.contains(id) {
717717
return Some(TransitionFailure::ReportReplayed);
718718
}
719719

@@ -920,17 +920,13 @@ impl DapReportInitializer for MockAggregator {
920920
)?;
921921

922922
let mut early_fails = HashMap::new();
923-
for (bucket, reports_consumed_per_bucket) in span.iter() {
924-
for metadata in reports_consumed_per_bucket
925-
.iter()
926-
.map(|report| report.metadata())
927-
{
923+
for (bucket, ((), report_ids_and_time)) in span.iter() {
924+
for (id, _) in report_ids_and_time {
928925
// Check whether Report has been collected or replayed.
929-
if let Some(transition_failure) = self
930-
.check_report_early_fail(task_id, bucket, metadata)
931-
.await
926+
if let Some(transition_failure) =
927+
self.check_report_early_fail(task_id, bucket, id).await
932928
{
933-
early_fails.insert(metadata.id.clone(), transition_failure);
929+
early_fails.insert(id.clone(), transition_failure);
934930
};
935931
}
936932
}
@@ -1233,7 +1229,7 @@ impl DapLeader<BearerToken> for MockAggregator {
12331229

12341230
// Check whether Report has been collected or replayed.
12351231
if let Some(transition_failure) = self
1236-
.check_report_early_fail(task_id, &bucket, &report.report_metadata)
1232+
.check_report_early_fail(task_id, &bucket, &report.report_metadata.id)
12371233
.await
12381234
{
12391235
return Err(DapError::Transition(transition_failure));

daphne/src/vdaf/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,9 @@ impl<'req> EarlyReportStateConsumed<'req> {
187187
})
188188
}
189189

190+
/// Convert this [`EarlyReportStateConsumed`] into a rejected [`EarlyReportStateInitialized`],
191+
/// using `failure` as the reason. If this report is already rejected, the passed-in `failure`
192+
/// value overwrites the previous one.
190193
pub fn into_initialized_rejected_due_to(
191194
self,
192195
failure: TransitionFailure,

daphne_worker/src/durable/reports_processed.rs

+21-90
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,19 @@ use daphne::{
1414
},
1515
DapError, VdafConfig,
1616
};
17-
use futures::{
18-
future::{ready, try_join_all},
19-
StreamExt, TryStreamExt,
20-
};
17+
use futures::{future::try_join_all, StreamExt, TryStreamExt};
2118
use prio::codec::{CodecError, ParameterizedDecode};
2219
use serde::{Deserialize, Serialize};
23-
use std::{borrow::Cow, collections::HashSet, ops::ControlFlow, time::Duration};
20+
use std::{borrow::Cow, collections::HashSet, future::ready, ops::ControlFlow, time::Duration};
2421
use tracing::Instrument;
2522
use worker::*;
2623

27-
use super::{req_parse, Alarmed, DapDurableObject, GarbageCollectable};
24+
use super::{req_parse, state_set_if_not_exists, Alarmed, DapDurableObject, GarbageCollectable};
2825

2926
pub(crate) const DURABLE_REPORTS_PROCESSED_INITIALIZE: &str =
3027
"/internal/do/reports_processed/initialize";
31-
pub(crate) const DURABLE_REPORTS_PROCESSED_MARK_AGGREGATED: &str =
32-
"/internal/do/reports_processed/mark_aggregated";
28+
pub(crate) const DURABLE_REPORTS_PROCESSED_INITIALIZED: &str =
29+
"/internal/do/reports_processed/initialized";
3330

3431
/// Durable Object (DO) for tracking which reports have been processed.
3532
///
@@ -63,63 +60,6 @@ impl<'id> From<&'id ReportId> for ReportIdKey<'id> {
6360
}
6461
}
6562

66-
#[derive(Debug)]
67-
enum CheckedReplays<'s> {
68-
SomeReplayed(Vec<&'s ReportId>),
69-
AllFresh(Vec<ReportIdKey<'s>>),
70-
}
71-
72-
impl<'r> Default for CheckedReplays<'r> {
73-
fn default() -> Self {
74-
Self::AllFresh(vec![])
75-
}
76-
}
77-
78-
impl<'r> CheckedReplays<'r> {
79-
fn add_replay(mut self, id: &'r ReportId) -> Self {
80-
match &mut self {
81-
Self::SomeReplayed(r) => {
82-
r.push(id);
83-
self
84-
}
85-
Self::AllFresh(_) => Self::SomeReplayed(vec![id]),
86-
}
87-
}
88-
89-
fn add_fresh(mut self, id: ReportIdKey<'r>) -> Self {
90-
match &mut self {
91-
Self::SomeReplayed(_) => {}
92-
Self::AllFresh(r) => r.push(id),
93-
}
94-
self
95-
}
96-
}
97-
98-
impl ReportsProcessed {
99-
async fn check_replays<'s>(&self, report_ids: &'s [ReportId]) -> Result<CheckedReplays<'s>> {
100-
futures::stream::iter(report_ids.iter().map(ReportIdKey::from))
101-
.then(|id| {
102-
let state = &self.state;
103-
async move {
104-
state_get::<bool>(state, &id.1)
105-
.await
106-
.map(|presence| match presence {
107-
// if it's present then it's a replay
108-
Some(true) => Err(id.0),
109-
Some(false) | None => Ok(id),
110-
})
111-
}
112-
})
113-
.try_fold(CheckedReplays::default(), |acc, id| async move {
114-
Ok(match id {
115-
Ok(not_replayed) => acc.add_fresh(not_replayed),
116-
Err(replayed) => acc.add_replay(replayed),
117-
})
118-
})
119-
.await
120-
}
121-
}
122-
12363
#[durable_object]
12464
impl DurableObject for ReportsProcessed {
12565
fn new(state: State, env: Env) -> Self {
@@ -166,6 +106,22 @@ impl ReportsProcessed {
166106
.await?;
167107

168108
match (req.path().as_ref(), req.method()) {
109+
// Mark a set of reports under their `processed/{id}` keys and report which were flagged.
//
// Input: `Vec<ReportId>`
// Output: JSON list of the report IDs for which `state_set_if_not_exists` returned `Some`
//
// NOTE(review): presumably `Some` means the key already existed, i.e. the report is a
// replay; confirm against `state_set_if_not_exists`.
(DURABLE_REPORTS_PROCESSED_INITIALIZED, Method::Post) => {
110+
let to_mark = req_parse::<Vec<ReportId>>(&mut req).await?;
111+
let state = &self.state;
112+
let replays = futures::stream::iter(&to_mark)
113+
.map(|id| async move {
114+
state_set_if_not_exists(state, &format!("processed/{id}"), &true)
115+
.await
116+
// Keep the ID only when the helper returned `Some`.
.map(|o| o.is_some().then_some(id))
117+
})
118+
// No cap on the number of in-flight storage operations; results are yielded in
// completion order, so the output order may differ from the request order.
.buffer_unordered(usize::MAX)
119+
// Drop the `None` entries; any storage error aborts the whole request.
.try_filter_map(|replay| ready(Ok(replay)))
120+
.try_collect::<Vec<_>>()
121+
.await?;
122+
123+
Response::from_json(&replays)
124+
}
169125
// Initialize a report:
170126
// * Ensure the report wasn't replayed
171127
// * Ensure the report won't be included in a batch that was already collected
@@ -230,31 +186,6 @@ impl ReportsProcessed {
230186
})
231187
}
232188

233-
// Mark reports as aggregated.
234-
//
235-
// If there are any replays, no reports are marked as aggregated.
236-
//
237-
// Idempotent
238-
// Input: `Vec<ReportId>`
239-
// Output: `Vec<ReportId>`
240-
(DURABLE_REPORTS_PROCESSED_MARK_AGGREGATED, Method::Post) => {
241-
let report_ids: Vec<ReportId> = req_parse(&mut req).await?;
242-
match self.check_replays(&report_ids).await? {
243-
CheckedReplays::SomeReplayed(report_ids) => Response::from_json(&report_ids),
244-
CheckedReplays::AllFresh(report_ids) => {
245-
let state = &self.state;
246-
futures::stream::iter(&report_ids)
247-
.then(|report_id| async move {
248-
state.storage().put(&report_id.1, &true).await
249-
})
250-
.try_for_each(|_| ready(Ok(())))
251-
.await?;
252-
253-
Response::from_json(&[(); 0])
254-
}
255-
}
256-
}
257-
258189
_ => Err(int_err(format!(
259190
"ReportsProcessed: unexpected request: method={:?}; path={:?}",
260191
req.method(),

0 commit comments

Comments
 (0)