Commit 1bc774f

Remove MARK_AGGREGATED and INITIALIZE from ReportsProcessed
1 parent 85f905d commit 1bc774f

5 files changed: +171 -262 lines

daphne/src/lib.rs

+19 -13

@@ -67,6 +67,7 @@ use constants::DapMediaType;
 use criterion as _;
 pub use error::DapError;
 use hpke::{HpkeConfig, HpkeKemId};
+use messages::ReportMetadata;
 use prio::{
     codec::{Decode, Encode, ParameterizedDecode},
     vdaf::Aggregatable as AggregatableTrait,
@@ -484,7 +485,7 @@ impl DapTaskConfig {
         &self,
         part_batch_sel: &'sel PartialBatchSelector,
         consumed_reports: impl Iterator<Item = &'rep EarlyReportStateConsumed<'rep>>,
-    ) -> Result<HashMap<DapBatchBucket, Vec<&'rep EarlyReportStateConsumed<'rep>>>, DapError> {
+    ) -> Result<HashMap<DapBatchBucket, Vec<ReportMetadata>>, DapError> {
         if !self.query.is_valid_part_batch_sel(part_batch_sel) {
             return Err(fatal_error!(
                 err = "partial batch selector not compatible with task",
@@ -494,24 +495,29 @@ impl DapTaskConfig {
         let mut span: HashMap<_, Vec<_>> = HashMap::new();
         for consumed_report in consumed_reports.filter(|consumed_report| consumed_report.is_ready())
         {
-            let bucket = match part_batch_sel {
-                PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
-                    batch_window: self.quantized_time_lower_bound(consumed_report.metadata().time),
-                },
-                PartialBatchSelector::FixedSizeByBatchId { batch_id } => {
-                    DapBatchBucket::FixedSize {
-                        batch_id: batch_id.clone(),
-                    }
-                }
-            };
-
+            let bucket = self.bucket_for(part_batch_sel, consumed_report);
             let consumed_reports_per_bucket = span.entry(bucket).or_default();
-            consumed_reports_per_bucket.push(consumed_report);
+            consumed_reports_per_bucket.push(consumed_report.metadata().clone());
         }

         Ok(span)
     }

+    pub fn bucket_for(
+        &self,
+        part_batch_sel: &PartialBatchSelector,
+        consumed_report: &EarlyReportStateConsumed<'_>,
+    ) -> DapBatchBucket {
+        match part_batch_sel {
+            PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
+                batch_window: self.quantized_time_lower_bound(consumed_report.metadata().time),
+            },
+            PartialBatchSelector::FixedSizeByBatchId { batch_id } => DapBatchBucket::FixedSize {
+                batch_id: batch_id.clone(),
+            },
+        }
+    }
+
     /// Check if the batch size is too small. Returns an error if the report count is too large.
     pub(crate) fn is_report_count_compatible(
         &self,
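
Note: the batch span returned by this method now maps each DapBatchBucket to the ReportMetadata of its reports, and the bucket computation is factored into the new public bucket_for helper. Below is a minimal, self-contained sketch of that bucketing logic using hypothetical trimmed-down stand-ins for the crate's types (the real DapBatchBucket, PartialBatchSelector, and ReportMetadata carry more fields, and the rounding shown is an assumption about quantized_time_lower_bound's behavior):

use std::collections::HashMap;

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
enum BatchBucket {
    TimeInterval { batch_window: u64 },
    FixedSize { batch_id: [u8; 32] },
}

enum PartialBatchSelector {
    TimeInterval,
    FixedSizeByBatchId { batch_id: [u8; 32] },
}

#[derive(Clone, Debug)]
struct ReportMetadata {
    time: u64,
}

/// Round a report timestamp down to the start of its batch window
/// (assumed behavior of quantized_time_lower_bound).
fn quantized_time_lower_bound(time: u64, time_precision: u64) -> u64 {
    time - (time % time_precision)
}

/// Assign a report to a bucket: by time window for time-interval queries,
/// by the Leader-chosen batch ID for fixed-size queries.
fn bucket_for(
    sel: &PartialBatchSelector,
    metadata: &ReportMetadata,
    time_precision: u64,
) -> BatchBucket {
    match sel {
        PartialBatchSelector::TimeInterval => BatchBucket::TimeInterval {
            batch_window: quantized_time_lower_bound(metadata.time, time_precision),
        },
        PartialBatchSelector::FixedSizeByBatchId { batch_id } => BatchBucket::FixedSize {
            batch_id: *batch_id,
        },
    }
}

fn main() {
    let reports = [ReportMetadata { time: 1000 }, ReportMetadata { time: 7000 }];
    let mut span: HashMap<BatchBucket, Vec<ReportMetadata>> = HashMap::new();
    for m in &reports {
        let bucket = bucket_for(&PartialBatchSelector::TimeInterval, m, 3600);
        span.entry(bucket).or_default().push(m.clone());
    }
    // t=1000 falls in the window starting at 0; t=7000 in the window at 3600.
    assert_eq!(span.len(), 2);
}

Fixed-size queries ignore the timestamp entirely: every report in the aggregation job lands in the single bucket named by the batch ID.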

daphne/src/roles/mod.rs

+30 -30

@@ -9,7 +9,7 @@ mod leader;
 
 use crate::{
     constants::DapMediaType,
-    messages::{BatchSelector, ReportMetadata, TaskId, Time, TransitionFailure},
+    messages::{BatchSelector, ReportMetadata, TaskId, Time},
     taskprov::{self, TaskprovVersion},
     DapAbort, DapError, DapQueryConfig, DapRequest, DapTaskConfig,
 };
@@ -103,34 +103,6 @@ async fn check_batch<S>(
     Ok(())
 }
 
-/// Check for transition failures due to:
-///
-/// * the report having already been processed
-/// * the report having already been collected
-/// * the report not being within time bounds
-///
-/// Returns `Some(TransitionFailure)` if there is a problem,
-/// or `None` if no transition failure occurred.
-pub fn early_metadata_check(
-    metadata: &ReportMetadata,
-    processed: bool,
-    collected: bool,
-    min_time: u64,
-    max_time: u64,
-) -> Option<TransitionFailure> {
-    if processed {
-        Some(TransitionFailure::ReportReplayed)
-    } else if collected {
-        Some(TransitionFailure::BatchCollected)
-    } else if metadata.time < min_time {
-        Some(TransitionFailure::ReportDropped)
-    } else if metadata.time > max_time {
-        Some(TransitionFailure::ReportTooEarly)
-    } else {
-        None
-    }
-}
-
 fn check_request_content_type<S>(
     req: &DapRequest<S>,
     expected: DapMediaType,
@@ -195,7 +167,7 @@ async fn resolve_taskprov<S>(
 
 #[cfg(test)]
 mod test {
-    use super::{early_metadata_check, DapAggregator, DapAuthorizedSender, DapHelper, DapLeader};
+    use super::{DapAggregator, DapAuthorizedSender, DapHelper, DapLeader};
     use crate::{
         assert_metrics_include, async_test_version, async_test_versions,
         auth::BearerToken,
@@ -234,6 +206,34 @@ mod test {
         }};
     }
 
+    /// Check for transition failures due to:
+    ///
+    /// * the report having already been processed
+    /// * the report having already been collected
+    /// * the report not being within time bounds
+    ///
+    /// Returns `Some(TransitionFailure)` if there is a problem,
+    /// or `None` if no transition failure occurred.
+    pub fn early_metadata_check(
+        metadata: &ReportMetadata,
+        processed: bool,
+        collected: bool,
+        min_time: u64,
+        max_time: u64,
+    ) -> Option<TransitionFailure> {
+        if processed {
+            Some(TransitionFailure::ReportReplayed)
+        } else if collected {
+            Some(TransitionFailure::BatchCollected)
+        } else if metadata.time < min_time {
+            Some(TransitionFailure::ReportDropped)
+        } else if metadata.time > max_time {
+            Some(TransitionFailure::ReportTooEarly)
+        } else {
+            None
+        }
+    }
+
     pub(super) struct TestData {
         pub now: Time,
         global_config: DapGlobalConfig,
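
Note: early_metadata_check is unchanged here, only demoted from the public API into the test module. Its checks are order-sensitive: replay wins over collection, which wins over the time bounds. A hedged usage sketch, restating the helper against trimmed-down stand-ins for the crate's types (the real ReportMetadata and TransitionFailure live in crate::messages and carry more detail):

/// Trimmed-down stand-in: the real struct has more fields.
struct ReportMetadata {
    time: u64,
}

#[derive(Debug, PartialEq)]
enum TransitionFailure {
    ReportReplayed,
    BatchCollected,
    ReportDropped,
    ReportTooEarly,
}

fn early_metadata_check(
    metadata: &ReportMetadata,
    processed: bool,
    collected: bool,
    min_time: u64,
    max_time: u64,
) -> Option<TransitionFailure> {
    if processed {
        Some(TransitionFailure::ReportReplayed)
    } else if collected {
        Some(TransitionFailure::BatchCollected)
    } else if metadata.time < min_time {
        Some(TransitionFailure::ReportDropped)
    } else if metadata.time > max_time {
        Some(TransitionFailure::ReportTooEarly)
    } else {
        None
    }
}

fn main() {
    let stale = ReportMetadata { time: 10 };
    // A replayed report that is also stale still reports ReportReplayed first.
    assert_eq!(
        early_metadata_check(&stale, true, false, 100, 200),
        Some(TransitionFailure::ReportReplayed)
    );
    // Once replay and collection are ruled out, the time bounds apply.
    assert_eq!(
        early_metadata_check(&stale, false, false, 100, 200),
        Some(TransitionFailure::ReportDropped)
    );
    assert_eq!(
        early_metadata_check(&ReportMetadata { time: 150 }, false, false, 100, 200),
        None
    );
}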

daphne/src/testing.rs

+1 -4

@@ -921,10 +921,7 @@ impl DapReportInitializer for MockAggregator {
 
         let mut early_fails = HashMap::new();
         for (bucket, reports_consumed_per_bucket) in span.iter() {
-            for metadata in reports_consumed_per_bucket
-                .iter()
-                .map(|report| report.metadata())
-            {
+            for metadata in reports_consumed_per_bucket.iter() {
                 // Check whether Report has been collected or replayed.
                 if let Some(transition_failure) = self
                     .check_report_early_fail(task_id, bucket, metadata)
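
Because the span now carries ReportMetadata values directly, consumers like MockAggregator drop the .map(|report| report.metadata()) projection. A small sketch of the resulting consumption pattern, with hypothetical stand-ins for the bucket type and the early-failure check:

use std::collections::HashMap;

#[derive(Clone)]
struct ReportMetadata {
    time: u64,
}

/// Hypothetical stand-in for check_report_early_fail.
fn check_early_fail(metadata: &ReportMetadata, min_time: u64) -> Option<&'static str> {
    (metadata.time < min_time).then_some("report-dropped")
}

fn main() {
    // Buckets keyed by batch window, each holding metadata values directly.
    let mut span: HashMap<u64, Vec<ReportMetadata>> = HashMap::new();
    span.entry(0).or_default().push(ReportMetadata { time: 5 });
    span.entry(3600).or_default().push(ReportMetadata { time: 3700 });

    let mut early_fails = HashMap::new();
    for (bucket, metadata_per_bucket) in span.iter() {
        for metadata in metadata_per_bucket.iter() {
            if let Some(failure) = check_early_fail(metadata, 10) {
                early_fails.insert(*bucket, failure);
            }
        }
    }
    assert_eq!(early_fails.get(&0), Some(&"report-dropped"));
}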

daphne_worker/src/durable/reports_processed.rs

+21 -90

@@ -14,22 +14,19 @@ use daphne::{
     },
     DapError, VdafConfig,
 };
-use futures::{
-    future::{ready, try_join_all},
-    StreamExt, TryStreamExt,
-};
+use futures::{future::try_join_all, StreamExt, TryStreamExt};
 use prio::codec::{CodecError, ParameterizedDecode};
 use serde::{Deserialize, Serialize};
-use std::{borrow::Cow, collections::HashSet, ops::ControlFlow, time::Duration};
+use std::{borrow::Cow, collections::HashSet, future::ready, ops::ControlFlow, time::Duration};
 use tracing::Instrument;
 use worker::*;
 
-use super::{req_parse, Alarmed, DapDurableObject, GarbageCollectable};
+use super::{req_parse, state_set_if_not_exists, Alarmed, DapDurableObject, GarbageCollectable};
 
 pub(crate) const DURABLE_REPORTS_PROCESSED_INITIALIZE: &str =
     "/internal/do/reports_processed/initialize";
-pub(crate) const DURABLE_REPORTS_PROCESSED_MARK_AGGREGATED: &str =
-    "/internal/do/reports_processed/mark_aggregated";
+pub(crate) const DURABLE_REPORTS_PROCESSED_INITIALIZED: &str =
+    "/internal/do/reports_processed/initialized";
 
 /// Durable Object (DO) for tracking which reports have been processed.
 ///
@@ -63,63 +60,6 @@ impl<'id> From<&'id ReportId> for ReportIdKey<'id> {
     }
 }
 
-#[derive(Debug)]
-enum CheckedReplays<'s> {
-    SomeReplayed(Vec<&'s ReportId>),
-    AllFresh(Vec<ReportIdKey<'s>>),
-}
-
-impl<'r> Default for CheckedReplays<'r> {
-    fn default() -> Self {
-        Self::AllFresh(vec![])
-    }
-}
-
-impl<'r> CheckedReplays<'r> {
-    fn add_replay(mut self, id: &'r ReportId) -> Self {
-        match &mut self {
-            Self::SomeReplayed(r) => {
-                r.push(id);
-                self
-            }
-            Self::AllFresh(_) => Self::SomeReplayed(vec![id]),
-        }
-    }
-
-    fn add_fresh(mut self, id: ReportIdKey<'r>) -> Self {
-        match &mut self {
-            Self::SomeReplayed(_) => {}
-            Self::AllFresh(r) => r.push(id),
-        }
-        self
-    }
-}
-
-impl ReportsProcessed {
-    async fn check_replays<'s>(&self, report_ids: &'s [ReportId]) -> Result<CheckedReplays<'s>> {
-        futures::stream::iter(report_ids.iter().map(ReportIdKey::from))
-            .then(|id| {
-                let state = &self.state;
-                async move {
-                    state_get::<bool>(state, &id.1)
-                        .await
-                        .map(|presence| match presence {
-                            // if it's present then it's a replay
-                            Some(true) => Err(id.0),
-                            Some(false) | None => Ok(id),
-                        })
-                }
-            })
-            .try_fold(CheckedReplays::default(), |acc, id| async move {
-                Ok(match id {
-                    Ok(not_replayed) => acc.add_fresh(not_replayed),
-                    Err(replayed) => acc.add_replay(replayed),
-                })
-            })
-            .await
-    }
-}
-
 #[durable_object]
 impl DurableObject for ReportsProcessed {
     fn new(state: State, env: Env) -> Self {
@@ -166,6 +106,22 @@ impl ReportsProcessed {
             .await?;
 
         match (req.path().as_ref(), req.method()) {
+            (DURABLE_REPORTS_PROCESSED_INITIALIZED, Method::Post) => {
+                let to_mark = req_parse::<Vec<ReportId>>(&mut req).await?;
+                let state = &self.state;
+                let replays = futures::stream::iter(&to_mark)
+                    .map(|id| async move {
+                        state_set_if_not_exists(state, &format!("processed/{id}"), &true)
+                            .await
+                            .map(|o| o.is_some().then_some(id))
+                    })
+                    .buffer_unordered(usize::MAX)
+                    .try_filter_map(|replay| ready(Ok(replay)))
+                    .try_collect::<Vec<_>>()
+                    .await?;
+
+                Response::from_json(&replays)
+            }
             // Initialize a report:
             // * Ensure the report wasn't replayed
            // * Ensure the report won't be included in a batch that was already collected
@@ -230,31 +186,6 @@ impl ReportsProcessed {
                 })
             }
 
-            // Mark reports as aggregated.
-            //
-            // If there are any replays, no reports are marked as aggregated.
-            //
-            // Idempotent
-            // Input: `Vec<ReportId>`
-            // Output: `Vec<ReportId>`
-            (DURABLE_REPORTS_PROCESSED_MARK_AGGREGATED, Method::Post) => {
-                let report_ids: Vec<ReportId> = req_parse(&mut req).await?;
-                match self.check_replays(&report_ids).await? {
-                    CheckedReplays::SomeReplayed(report_ids) => Response::from_json(&report_ids),
-                    CheckedReplays::AllFresh(report_ids) => {
-                        let state = &self.state;
-                        futures::stream::iter(&report_ids)
-                            .then(|report_id| async move {
-                                state.storage().put(&report_id.1, &true).await
-                            })
-                            .try_for_each(|_| ready(Ok(())))
-                            .await?;
-
-                        Response::from_json(&[(); 0])
-                    }
-                }
-            }
-
             _ => Err(int_err(format!(
                 "ReportsProcessed: unexpected request: method={:?}; path={:?}",
                 req.method(),
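
The replaced endpoint changes semantics as well as shape: the removed MARK_AGGREGATED path refused to mark anything once a single replay was found, while the new INITIALIZED handler marks every fresh ID in one pass and returns the replays, collapsing the earlier read-then-write (state_get followed by put) into a single state_set_if_not_exists per report. A std-only sketch of those semantics, with a HashSet standing in for Durable Object storage:

use std::collections::HashSet;

/// Try to insert every report ID under a "processed/{id}" key; return the IDs
/// whose keys already existed, i.e. the replays. Fresh IDs are still marked
/// even when the same batch contains replays.
fn mark_initialized<'s>(storage: &mut HashSet<String>, to_mark: &'s [String]) -> Vec<&'s str> {
    to_mark
        .iter()
        // `insert` returns false when the key already existed: a replay.
        .filter(|id| !storage.insert(format!("processed/{id}")))
        .map(String::as_str)
        .collect()
}

fn main() {
    let mut storage = HashSet::new();
    let first = vec!["report-a".to_string(), "report-b".to_string()];
    assert!(mark_initialized(&mut storage, &first).is_empty());

    // Re-submitting "report-a" alongside a fresh "report-c" flags only the
    // replay, and "report-c" is still marked as processed.
    let second = vec!["report-a".to_string(), "report-c".to_string()];
    assert_eq!(mark_initialized(&mut storage, &second), vec!["report-a"]);
    assert!(storage.contains("processed/report-c"));
}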
