
Commit

Remove MARK_AGGREGATED and INITIALIZE from ReportsProcessed
mendess committed Oct 26, 2023
1 parent 85f905d commit f2d01de
Showing 5 changed files with 168 additions and 257 deletions.
32 changes: 19 additions & 13 deletions daphne/src/lib.rs
@@ -67,6 +67,7 @@ use constants::DapMediaType;
use criterion as _;
pub use error::DapError;
use hpke::{HpkeConfig, HpkeKemId};
use messages::ReportMetadata;
use prio::{
codec::{Decode, Encode, ParameterizedDecode},
vdaf::Aggregatable as AggregatableTrait,
@@ -484,7 +485,7 @@ impl DapTaskConfig {
&self,
part_batch_sel: &'sel PartialBatchSelector,
consumed_reports: impl Iterator<Item = &'rep EarlyReportStateConsumed<'rep>>,
) -> Result<HashMap<DapBatchBucket, Vec<&'rep EarlyReportStateConsumed<'rep>>>, DapError> {
) -> Result<HashMap<DapBatchBucket, Vec<ReportMetadata>>, DapError> {
if !self.query.is_valid_part_batch_sel(part_batch_sel) {
return Err(fatal_error!(
err = "partial batch selector not compatible with task",
@@ -494,24 +495,29 @@
let mut span: HashMap<_, Vec<_>> = HashMap::new();
for consumed_report in consumed_reports.filter(|consumed_report| consumed_report.is_ready())
{
let bucket = match part_batch_sel {
PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
batch_window: self.quantized_time_lower_bound(consumed_report.metadata().time),
},
PartialBatchSelector::FixedSizeByBatchId { batch_id } => {
DapBatchBucket::FixedSize {
batch_id: batch_id.clone(),
}
}
};

let bucket = self.bucket_for(part_batch_sel, consumed_report);
let consumed_reports_per_bucket = span.entry(bucket).or_default();
consumed_reports_per_bucket.push(consumed_report);
consumed_reports_per_bucket.push(consumed_report.metadata().clone());
}

Ok(span)
}

pub fn bucket_for(
&self,
part_batch_sel: &PartialBatchSelector,
consumed_report: &EarlyReportStateConsumed<'_>,
) -> DapBatchBucket {
match part_batch_sel {
PartialBatchSelector::TimeInterval => DapBatchBucket::TimeInterval {
batch_window: self.quantized_time_lower_bound(consumed_report.metadata().time),
},
PartialBatchSelector::FixedSizeByBatchId { batch_id } => DapBatchBucket::FixedSize {
batch_id: batch_id.clone(),
},
}
}

/// Check if the batch size is too small. Returns an error if the report count is too large.
pub(crate) fn is_report_count_compatible(
&self,
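With this change, the per-bucket span built by `DapTaskConfig` maps each `DapBatchBucket` to owned `ReportMetadata` values instead of borrowed `EarlyReportStateConsumed` references, and the bucket computation is factored into the new `bucket_for` helper. A minimal sketch of consuming the new return type, using simplified stand-ins for the daphne types (not code from this commit):

```rust
use std::collections::HashMap;

// Simplified stand-ins for daphne::DapBatchBucket and
// daphne::messages::ReportMetadata; the real types carry more fields.
#[derive(Debug, Hash, PartialEq, Eq)]
enum DapBatchBucket {
    TimeInterval { batch_window: u64 },
}

#[derive(Debug, Clone)]
struct ReportMetadata {
    time: u64,
}

fn main() {
    let mut span: HashMap<DapBatchBucket, Vec<ReportMetadata>> = HashMap::new();
    span.entry(DapBatchBucket::TimeInterval { batch_window: 3600 })
        .or_default()
        .push(ReportMetadata { time: 3700 });

    // Downstream code can now iterate owned metadata per bucket without
    // holding a borrow of the consumed reports.
    for (bucket, reports) in &span {
        let times: Vec<u64> = reports.iter().map(|m| m.time).collect();
        println!("{bucket:?}: {} report(s) at times {times:?}", reports.len());
    }
}
```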
60 changes: 30 additions & 30 deletions daphne/src/roles/mod.rs
@@ -9,7 +9,7 @@ mod leader;

use crate::{
constants::DapMediaType,
messages::{BatchSelector, ReportMetadata, TaskId, Time, TransitionFailure},
messages::{BatchSelector, ReportMetadata, TaskId, Time},
taskprov::{self, TaskprovVersion},
DapAbort, DapError, DapQueryConfig, DapRequest, DapTaskConfig,
};
@@ -103,34 +103,6 @@ async fn check_batch<S>(
Ok(())
}

/// Check for transition failures due to:
///
/// * the report having already been processed
/// * the report having already been collected
/// * the report not being within time bounds
///
/// Returns `Some(TransitionFailure)` if there is a problem,
/// or `None` if no transition failure occurred.
pub fn early_metadata_check(
metadata: &ReportMetadata,
processed: bool,
collected: bool,
min_time: u64,
max_time: u64,
) -> Option<TransitionFailure> {
if processed {
Some(TransitionFailure::ReportReplayed)
} else if collected {
Some(TransitionFailure::BatchCollected)
} else if metadata.time < min_time {
Some(TransitionFailure::ReportDropped)
} else if metadata.time > max_time {
Some(TransitionFailure::ReportTooEarly)
} else {
None
}
}

fn check_request_content_type<S>(
req: &DapRequest<S>,
expected: DapMediaType,
@@ -195,7 +167,7 @@ async fn resolve_taskprov<S>(

#[cfg(test)]
mod test {
use super::{early_metadata_check, DapAggregator, DapAuthorizedSender, DapHelper, DapLeader};
use super::{DapAggregator, DapAuthorizedSender, DapHelper, DapLeader};
use crate::{
assert_metrics_include, async_test_version, async_test_versions,
auth::BearerToken,
@@ -234,6 +206,34 @@ mod test {
}};
}

/// Check for transition failures due to:
///
/// * the report having already been processed
/// * the report having already been collected
/// * the report not being within time bounds
///
/// Returns `Some(TransitionFailure)` if there is a problem,
/// or `None` if no transition failure occurred.
pub fn early_metadata_check(
metadata: &ReportMetadata,
processed: bool,
collected: bool,
min_time: u64,
max_time: u64,
) -> Option<TransitionFailure> {
if processed {
Some(TransitionFailure::ReportReplayed)
} else if collected {
Some(TransitionFailure::BatchCollected)
} else if metadata.time < min_time {
Some(TransitionFailure::ReportDropped)
} else if metadata.time > max_time {
Some(TransitionFailure::ReportTooEarly)
} else {
None
}
}

pub(super) struct TestData {
pub now: Time,
global_config: DapGlobalConfig,
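`early_metadata_check` is no longer used by production code, so the commit moves it verbatim into the test module. The following self-contained sketch illustrates its behavior; the `ReportMetadata` and `TransitionFailure` types below are simplified stand-ins for the real daphne::messages types, and the timestamps are illustrative:

```rust
#[derive(Debug, PartialEq)]
enum TransitionFailure {
    ReportReplayed,
    BatchCollected,
    ReportDropped,
    ReportTooEarly,
}

struct ReportMetadata {
    time: u64,
}

// Same logic as the helper shown in the diff above.
fn early_metadata_check(
    metadata: &ReportMetadata,
    processed: bool,
    collected: bool,
    min_time: u64,
    max_time: u64,
) -> Option<TransitionFailure> {
    if processed {
        Some(TransitionFailure::ReportReplayed)
    } else if collected {
        Some(TransitionFailure::BatchCollected)
    } else if metadata.time < min_time {
        Some(TransitionFailure::ReportDropped)
    } else if metadata.time > max_time {
        Some(TransitionFailure::ReportTooEarly)
    } else {
        None
    }
}

fn main() {
    let report = ReportMetadata { time: 1_000 };
    // A replayed report is rejected before any other check.
    assert_eq!(
        early_metadata_check(&report, true, false, 0, 2_000),
        Some(TransitionFailure::ReportReplayed)
    );
    // A report older than the accepted window is dropped.
    assert_eq!(
        early_metadata_check(&report, false, false, 1_500, 2_000),
        Some(TransitionFailure::ReportDropped)
    );
    // Otherwise the report passes the early checks.
    assert_eq!(early_metadata_check(&report, false, false, 0, 2_000), None);
}
```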
5 changes: 1 addition & 4 deletions daphne/src/testing.rs
@@ -921,10 +921,7 @@ impl DapReportInitializer for MockAggregator {

let mut early_fails = HashMap::new();
for (bucket, reports_consumed_per_bucket) in span.iter() {
for metadata in reports_consumed_per_bucket
.iter()
.map(|report| report.metadata())
{
for metadata in reports_consumed_per_bucket.iter() {
// Check whether Report has been collected or replayed.
if let Some(transition_failure) = self
.check_report_early_fail(task_id, bucket, metadata)
111 changes: 21 additions & 90 deletions daphne_worker/src/durable/reports_processed.rs
@@ -14,22 +14,19 @@ use daphne::{
},
DapError, VdafConfig,
};
use futures::{
future::{ready, try_join_all},
StreamExt, TryStreamExt,
};
use futures::{future::try_join_all, StreamExt, TryStreamExt};
use prio::codec::{CodecError, ParameterizedDecode};
use serde::{Deserialize, Serialize};
use std::{borrow::Cow, collections::HashSet, ops::ControlFlow, time::Duration};
use std::{borrow::Cow, collections::HashSet, future::ready, ops::ControlFlow, time::Duration};
use tracing::Instrument;
use worker::*;

use super::{req_parse, Alarmed, DapDurableObject, GarbageCollectable};
use super::{req_parse, state_set_if_not_exists, Alarmed, DapDurableObject, GarbageCollectable};

pub(crate) const DURABLE_REPORTS_PROCESSED_INITIALIZE: &str =
"/internal/do/reports_processed/initialize";
pub(crate) const DURABLE_REPORTS_PROCESSED_MARK_AGGREGATED: &str =
"/internal/do/reports_processed/mark_aggregated";
pub(crate) const DURABLE_REPORTS_PROCESSED_INITIALIZED: &str =
"/internal/do/reports_processed/initialized";

/// Durable Object (DO) for tracking which reports have been processed.
///
@@ -63,63 +60,6 @@ impl<'id> From<&'id ReportId> for ReportIdKey<'id> {
}
}

#[derive(Debug)]
enum CheckedReplays<'s> {
SomeReplayed(Vec<&'s ReportId>),
AllFresh(Vec<ReportIdKey<'s>>),
}

impl<'r> Default for CheckedReplays<'r> {
fn default() -> Self {
Self::AllFresh(vec![])
}
}

impl<'r> CheckedReplays<'r> {
fn add_replay(mut self, id: &'r ReportId) -> Self {
match &mut self {
Self::SomeReplayed(r) => {
r.push(id);
self
}
Self::AllFresh(_) => Self::SomeReplayed(vec![id]),
}
}

fn add_fresh(mut self, id: ReportIdKey<'r>) -> Self {
match &mut self {
Self::SomeReplayed(_) => {}
Self::AllFresh(r) => r.push(id),
}
self
}
}

impl ReportsProcessed {
async fn check_replays<'s>(&self, report_ids: &'s [ReportId]) -> Result<CheckedReplays<'s>> {
futures::stream::iter(report_ids.iter().map(ReportIdKey::from))
.then(|id| {
let state = &self.state;
async move {
state_get::<bool>(state, &id.1)
.await
.map(|presence| match presence {
// if it's present then it's a replay
Some(true) => Err(id.0),
Some(false) | None => Ok(id),
})
}
})
.try_fold(CheckedReplays::default(), |acc, id| async move {
Ok(match id {
Ok(not_replayed) => acc.add_fresh(not_replayed),
Err(replayed) => acc.add_replay(replayed),
})
})
.await
}
}

#[durable_object]
impl DurableObject for ReportsProcessed {
fn new(state: State, env: Env) -> Self {
@@ -166,6 +106,22 @@ impl ReportsProcessed {
.await?;

match (req.path().as_ref(), req.method()) {
(DURABLE_REPORTS_PROCESSED_INITIALIZED, Method::Post) => {
let to_mark = req_parse::<Vec<ReportId>>(&mut req).await?;
let state = &self.state;
let replays = futures::stream::iter(&to_mark)
.map(|id| async move {
state_set_if_not_exists(state, &format!("processed/{id}"), &true)
.await
.map(|o| o.is_some().then_some(id))
})
.buffer_unordered(usize::MAX)
.try_filter_map(|replay| ready(Ok(replay)))
.try_collect::<Vec<_>>()
.await?;

Response::from_json(&replays)
}
// Initialize a report:
// * Ensure the report wasn't replayed
// * Ensure the report won't be included in a batch that was already collected
@@ -230,31 +186,6 @@ impl ReportsProcessed {
})
}

// Mark reports as aggregated.
//
// If there are any replays, no reports are marked as aggregated.
//
// Idempotent
// Input: `Vec<ReportId>`
// Output: `Vec<ReportId>`
(DURABLE_REPORTS_PROCESSED_MARK_AGGREGATED, Method::Post) => {
let report_ids: Vec<ReportId> = req_parse(&mut req).await?;
match self.check_replays(&report_ids).await? {
CheckedReplays::SomeReplayed(report_ids) => Response::from_json(&report_ids),
CheckedReplays::AllFresh(report_ids) => {
let state = &self.state;
futures::stream::iter(&report_ids)
.then(|report_id| async move {
state.storage().put(&report_id.1, &true).await
})
.try_for_each(|_| ready(Ok(())))
.await?;

Response::from_json(&[(); 0])
}
}
}

_ => Err(int_err(format!(
"ReportsProcessed: unexpected request: method={:?}; path={:?}",
req.method(),
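The new `initialized` endpoint collapses the old initialize/mark-aggregated pair into a single request: the caller POSTs the batch's report IDs, the Durable Object records each previously unseen ID via `state_set_if_not_exists`, and the response lists only the IDs that were already present, i.e. the replays. A minimal sketch of that contract using an in-memory set in place of Durable Object storage (everything except the endpoint path is illustrative, not part of this commit):

```rust
use std::collections::HashSet;

// Stand-in for a report ID; the real type is daphne::messages::ReportId.
type ReportId = String;

/// In-memory model of the DO's "processed/{id}" keys.
struct ReportsProcessedModel {
    processed: HashSet<ReportId>,
}

impl ReportsProcessedModel {
    /// Models a POST to /internal/do/reports_processed/initialized:
    /// every previously unseen ID is marked processed, and the IDs that
    /// were already present are returned as replays.
    fn initialized(&mut self, to_mark: Vec<ReportId>) -> Vec<ReportId> {
        to_mark
            .into_iter()
            // HashSet::insert returns false when the ID was already present,
            // which is exactly the replay case.
            .filter(|id| !self.processed.insert(id.clone()))
            .collect()
    }
}

fn main() {
    let mut object = ReportsProcessedModel { processed: HashSet::new() };
    assert!(object.initialized(vec!["a".into(), "b".into()]).is_empty());
    // Submitting "b" again reports it as a replay; "c" is fresh.
    assert_eq!(
        object.initialized(vec!["b".into(), "c".into()]),
        vec!["b".to_string()]
    );
}
```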