Improve type safeness of InitializedReport
mendess committed Dec 3, 2024
1 parent e27420f commit 69970f0
Showing 5 changed files with 374 additions and 280 deletions.
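
The headline change in this commit: InitializedReport moves out of protocol/aggregator.rs into a protocol::report_init module whose file is not shown in this excerpt, and whether a report carries the peer's prep share becomes a type parameter instead of an Option<Vec<u8>> checked at runtime. The re-export change in lib.rs below and the Helper-side return type Vec<InitializedReport<WithPeerPrepShare>> further down are the visible surface of that change. What follows is a minimal, self-contained sketch of the typestate idea, not the actual contents of report_init.rs; only the two public names are taken from the diff, and every field shown is an assumption.

// Sketch only, not part of this commit: a minimal typestate version of
// InitializedReport. Field layout is assumed, not taken from report_init.rs.

/// Marker for Helper-initialized reports, which always carry the Leader's
/// prep share taken from the AggregationJobInitReq payload.
pub struct WithPeerPrepShare(pub Vec<u8>);

/// In daphne the Ready variant also holds the report metadata, public share,
/// and the local VDAF prep state and share; those fields are elided here.
pub enum InitializedReport<P> {
    Ready { peer_prep_share: P },
    Rejected { failure: &'static str }, // TransitionFailure in daphne
}

/// Only the Helper instantiation can hand out the peer prep share, so the
/// "expected leader prep share, got none" fatal error removed later in this
/// diff has no equivalent here: the compiler rules that case out.
impl InitializedReport<WithPeerPrepShare> {
    pub fn leader_prep_share(&self) -> Option<&[u8]> {
        match self {
            Self::Ready { peer_prep_share } => Some(peer_prep_share.0.as_slice()),
            Self::Rejected { .. } => None,
        }
    }
}
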
2 changes: 1 addition & 1 deletion crates/daphne/src/lib.rs
@@ -91,7 +91,7 @@ use url::Url;
use vdaf::mastic::MasticWeight;

pub use messages::request::{DapRequest, DapRequestMeta, DapResponse};
pub use protocol::aggregator::InitializedReport;
pub use protocol::report_init::{InitializedReport, WithPeerPrepShare};

/// DAP version used for a task.
#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, Hash, PartialEq, Serialize)]
283 changes: 20 additions & 263 deletions crates/daphne/src/protocol/aggregator.rs
@@ -1,257 +1,40 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use super::{
no_duplicates,
report_init::{InitializedReport, WithPeerPrepShare},
};
#[cfg(feature = "experimental")]
use crate::vdaf::mastic::{mastic_prep_finish, mastic_prep_finish_from_shares, mastic_prep_init};
use crate::vdaf::mastic::{mastic_prep_finish, mastic_prep_finish_from_shares};
use crate::{
constants::DapAggregatorRole,
error::DapAbort,
fatal_error,
hpke::{info_and_aad, HpkeConfig, HpkeDecrypter},
messages::{
self, encode_u32_bytes, AggregationJobInitReq, AggregationJobResp, Base64Encode,
BatchSelector, Extension, HpkeCiphertext, PartialBatchSelector, PlaintextInputShare,
PrepareInit, Report, ReportId, ReportMetadata, ReportShare, TaskId, Transition,
TransitionFailure, TransitionVar,
BatchSelector, HpkeCiphertext, PartialBatchSelector, PrepareInit, Report, ReportId,
ReportShare, TaskId, Transition, TransitionFailure, TransitionVar,
},
metrics::{DaphneMetrics, ReportStatus},
protocol::{decode_ping_pong_framed, PingPongMessageType},
vdaf::{
prio2::{prio2_prep_finish, prio2_prep_finish_from_shares, prio2_prep_init},
prio3::{prio3_prep_finish, prio3_prep_finish_from_shares, prio3_prep_init},
VdafError, VdafPrepShare, VdafPrepState,
prio2::{prio2_prep_finish, prio2_prep_finish_from_shares},
prio3::{prio3_prep_finish, prio3_prep_finish_from_shares},
VdafError,
},
AggregationJobReportState, DapAggregateShare, DapAggregateSpan, DapAggregationJobState,
DapAggregationParam, DapError, DapTaskConfig, DapVersion, VdafConfig,
};
use prio::codec::{
encode_u32_items, CodecError, Decode, Encode, ParameterizedDecode, ParameterizedEncode,
};
use prio::codec::{encode_u32_items, Encode, ParameterizedDecode, ParameterizedEncode};
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use std::{
collections::{HashMap, HashSet},
io::Cursor,
iter::zip,
ops::Range,
};

use rayon::iter::{IntoParallelIterator, ParallelIterator as _};

// Ping-pong message framing as defined in draft-irtf-cfrg-vdaf-08, Section 5.8. We do not
// implement the "continue" message type because we only support 1-round VDAFs.
enum PingPongMessageType {
Initialize = 0,
Finish = 2,
}

// This is essentially a re-implementation of a method in the `messages` module. However the goal
// here is to make it zero-copy. See https://github.com/cloudflare/daphne/issues/15.
fn decode_ping_pong_framed(
bytes: &[u8],
expected_type: PingPongMessageType,
) -> Result<&[u8], CodecError> {
let mut r = Cursor::new(bytes);

let message_type = u8::decode(&mut r)?;
if message_type != expected_type as u8 {
return Err(CodecError::UnexpectedValue);
}

let message_len = u32::decode(&mut r)?.try_into().unwrap();
let message_start = usize::try_from(r.position()).unwrap();
if bytes.len() - message_start < message_len {
return Err(CodecError::LengthPrefixTooBig(message_len));
}
if bytes.len() - message_start > message_len {
return Err(CodecError::BytesLeftOver(message_len));
}

Ok(&bytes[message_start..])
}
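
decode_ping_pong_framed and PingPongMessageType are deleted here because they move into the shared protocol module; the new protocol::{decode_ping_pong_framed, PingPongMessageType} line in the crate imports at the top of this hunk pulls them back in. For reference, the frame they parse is one message-type byte, a 4-byte big-endian (TLS-style) length, then exactly that many payload bytes. Below is a self-contained round-trip sketch using only the standard library, mirroring the checks above without the prio codec traits; it is illustrative code, not part of the commit.

// Illustrative sketch, not part of this commit: encode and decode the
// ping-pong framing handled by decode_ping_pong_framed above.

fn encode_framed(message_type: u8, payload: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(1 + 4 + payload.len());
    out.push(message_type);
    out.extend_from_slice(&u32::try_from(payload.len()).unwrap().to_be_bytes());
    out.extend_from_slice(payload);
    out
}

fn decode_framed(bytes: &[u8], expected_type: u8) -> Option<&[u8]> {
    let (&message_type, rest) = bytes.split_first()?;
    if message_type != expected_type || rest.len() < 4 {
        return None;
    }
    let len = u32::from_be_bytes(rest[..4].try_into().unwrap()) as usize;
    let payload = &rest[4..];
    // Reject truncated and over-long buffers alike, as the original does.
    (payload.len() == len).then_some(payload)
}

fn main() {
    let framed = encode_framed(0, b"leader prep share"); // 0 = Initialize
    assert_eq!(decode_framed(&framed, 0), Some(&b"leader prep share"[..]));
    assert_eq!(decode_framed(&framed, 2), None); // 2 = Finish: wrong type
}
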

/// Report state during aggregation initialization after the VDAF preparation step.
#[expect(clippy::large_enum_variant)]
#[derive(Clone)]
#[cfg_attr(any(test, feature = "test-utils"), derive(Debug, deepsize::DeepSizeOf))]
pub enum InitializedReport {
Ready {
metadata: ReportMetadata,
public_share: Vec<u8>,
prep_share: VdafPrepShare,
prep_state: VdafPrepState,
// Set by the Helper.
peer_prep_share: Option<Vec<u8>>,
},
Rejected {
metadata: ReportMetadata,
failure: TransitionFailure,
},
}

impl InitializedReport {
#[expect(clippy::too_many_arguments)]
pub(crate) fn new(
decrypter: &impl HpkeDecrypter,
valid_report_range: Range<messages::Time>,
role: DapAggregatorRole,
task_id: &TaskId,
task_config: &DapTaskConfig,
report_share: ReportShare,
prep_init_payload: Option<Vec<u8>>,
// We need to use this variable for Mastic, which is currently fenced by the
// "experimental" feature.
#[cfg_attr(not(feature = "experimental"), expect(unused_variables))]
agg_param: &DapAggregationParam,
) -> Result<Self, DapError> {
macro_rules! reject {
($failure:ident) => {
return Ok(Self::Rejected {
metadata: report_share.report_metadata,
failure: TransitionFailure::$failure,
})
};
}
match report_share.report_metadata.time {
t if t >= task_config.not_after => reject!(TaskExpired),
t if t < valid_report_range.start => reject!(ReportDropped),
t if valid_report_range.end < t => reject!(ReportTooEarly),
_ => {}
}

// decrypt input share
let PlaintextInputShare {
extensions,
payload: input_share,
} = {
let info = info_and_aad::InputShare {
version: task_config.version,
receiver: role,
task_id,
report_metadata: &report_share.report_metadata,
public_share: &report_share.public_share,
};

let encoded_input_share =
match decrypter.hpke_decrypt(info, &report_share.encrypted_input_share) {
Ok(encoded_input_share) => encoded_input_share,
Err(DapError::Transition(failure)) => {
return Ok(Self::Rejected {
metadata: report_share.report_metadata,
failure,
})
}
Err(e) => return Err(e),
};

let Ok(plaintext) = PlaintextInputShare::get_decoded_with_param(
&task_config.version,
&encoded_input_share,
) else {
reject!(InvalidMessage)
};
plaintext
};

// Handle report extensions.
{
if no_duplicates(extensions.iter().map(|e| e.type_code())).is_err() {
reject!(InvalidMessage)
}
let mut taskprov_indicated = false;
for extension in extensions {
match extension {
Extension::Taskprov { .. } if task_config.method_is_taskprov() => {
taskprov_indicated = true;
}

// Reject reports with unrecognized extensions.
_ => reject!(InvalidMessage),
}
}

if task_config.method_is_taskprov() && !taskprov_indicated {
// taskprov: If the task configuration method is taskprov, then we expect each
// report to indicate support.
reject!(InvalidMessage);
}
}

// Decode the ping-pong "initialize" message framing.
// (draft-irtf-cfrg-vdaf-08, Section 5.8).
let peer_prep_share = match prep_init_payload
.as_ref()
.map(|payload| decode_ping_pong_framed(payload, PingPongMessageType::Initialize))
.transpose()
{
Ok(peer_prep_share) => peer_prep_share,
Err(e) => {
tracing::warn!(error = ?e, "rejecting report");
reject!(VdafPrepError);
}
};

let agg_id = match role {
DapAggregatorRole::Leader => 0,
DapAggregatorRole::Helper => 1,
};
let res = match &task_config.vdaf {
VdafConfig::Prio3(ref prio3_config) => prio3_prep_init(
prio3_config,
&task_config.vdaf_verify_key,
agg_id,
&report_share.report_metadata.id.0,
&report_share.public_share,
&input_share,
),
VdafConfig::Prio2 { dimension } => prio2_prep_init(
*dimension,
&task_config.vdaf_verify_key,
agg_id,
&report_share.report_metadata.id.0,
&report_share.public_share,
&input_share,
),
#[cfg(feature = "experimental")]
VdafConfig::Mastic {
input_size,
weight_config,
} => mastic_prep_init(
*input_size,
*weight_config,
&task_config.vdaf_verify_key,
agg_param,
&report_share.public_share,
input_share.as_ref(),
),
VdafConfig::Pine(pine) => pine.prep_init(
&task_config.vdaf_verify_key,
agg_id,
&report_share.report_metadata.id.0,
&report_share.public_share,
&input_share,
),
};

Ok(match res {
Ok((prep_state, prep_share)) => Self::Ready {
metadata: report_share.report_metadata,
public_share: report_share.public_share,
peer_prep_share: peer_prep_share.map(|p| p.to_vec()),
prep_share,
prep_state,
},
Err(e) => {
tracing::warn!(error = ?e, "rejecting report");
reject!(VdafPrepError);
}
})
}

pub(crate) fn metadata(&self) -> &ReportMetadata {
match self {
Self::Ready { metadata, .. } | Self::Rejected { metadata, .. } => metadata,
}
}
}
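
The constructor changes shape along with the type: instead of one new() taking a DapAggregatorRole and an Option<Vec<u8>> prep-init payload, the call sites below use InitializedReport::from_client on the Leader and InitializedReport::from_leader on the Helper, with the Helper's payload now required rather than optional. Continuing the self-contained sketch near the top of this page; the real constructors, defined in the unshown report_init module, also take the HPKE decrypter, valid report time range, task ID and config, report share, and aggregation parameter, as the calls below show.

// Sketch only: constructor shapes, with all daphne-specific parameters and
// the real initialization logic omitted.
impl InitializedReport<()> {
    /// Leader path (from_client below): the report comes straight from a
    /// client, so there is no peer prep share yet.
    pub fn sketch_from_client() -> Self {
        Self::Ready { peer_prep_share: () }
    }
}

impl InitializedReport<WithPeerPrepShare> {
    /// Helper path (from_leader below): the Leader's ping-pong "initialize"
    /// payload is required up front instead of being optional.
    pub fn sketch_from_leader(leader_prep_share: Vec<u8>) -> Self {
        Self::Ready {
            peer_prep_share: WithPeerPrepShare(leader_prep_share),
        }
    }
}
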

pub(crate) enum ReportProcessedStatus {
/// The report should be marked as aggregated. However it has already been committed to
/// storage, so don't do so again.
@@ -343,25 +126,23 @@ impl DapTaskConfig

let [leader_share, helper_share] = report.encrypted_input_shares;

let initialized_report = InitializedReport::new(
let initialized_report = InitializedReport::from_client(
&decrypter,
valid_report_time_range.clone(),
DapAggregatorRole::Leader,
task_id,
self,
ReportShare {
report_metadata: report.report_metadata,
public_share: report.public_share,
encrypted_input_share: leader_share,
},
None,
agg_param,
)?;
match initialized_report {
InitializedReport::Ready {
metadata,
public_share,
peer_prep_share: _,
peer_prep_share: (),
prep_share,
prep_state,
} => {
@@ -454,7 +235,7 @@ impl DapTaskConfig
task_id: &TaskId,
agg_job_init_req: AggregationJobInitReq,
replay_protection: ReplayProtection,
) -> Result<Vec<InitializedReport>, DapError>
) -> Result<Vec<InitializedReport<WithPeerPrepShare>>, DapError>
where
H: HpkeDecrypter + Sync,
{
@@ -478,14 +259,13 @@ impl DapTaskConfig
.prep_inits
.into_par_iter()
.map(|prep_init| {
InitializedReport::new(
InitializedReport::from_leader(
decrypter,
valid_report_time_range.clone(),
DapAggregatorRole::Helper,
task_id,
self,
prep_init.report_share,
Some(prep_init.payload),
prep_init.payload,
&agg_param,
)
})
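
The Helper still initializes each PrepareInit in parallel with rayon; the map now yields InitializedReport<WithPeerPrepShare> values and, judging by the Result<Vec<...>, DapError> return type above, is collected into a single Result (the tail of the expression falls outside the shown hunk). Below is a self-contained sketch of that collect-into-Result pattern with toy types, not daphne's.

// Illustrative sketch, not part of this commit.
use rayon::prelude::*;

#[derive(Debug, PartialEq)]
struct Initialized(u32);

// Initialize every item in parallel; any Err makes the whole batch fail,
// mirroring the Result<Vec<_>, DapError> signature in the diff.
fn init_all(prep_inits: Vec<u32>) -> Result<Vec<Initialized>, String> {
    prep_inits
        .into_par_iter()
        .map(|p| {
            if p % 2 == 0 {
                Ok(Initialized(p))
            } else {
                Err(format!("bad prep init {p}"))
            }
        })
        .collect()
}

fn main() {
    assert_eq!(init_all(vec![2, 4]), Ok(vec![Initialized(2), Initialized(4)]));
    assert!(init_all(vec![2, 3]).is_err());
}
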
@@ -499,7 +279,7 @@ impl DapTaskConfig
&self,
report_status: &HashMap<ReportId, ReportProcessedStatus>,
part_batch_sel: &PartialBatchSelector,
initialized_reports: &[InitializedReport],
initialized_reports: &[InitializedReport<WithPeerPrepShare>],
) -> Result<(DapAggregateSpan<DapAggregateShare>, AggregationJobResp), DapError> {
let num_reports = initialized_reports.len();
let mut agg_span = DapAggregateSpan::default();
@@ -513,7 +293,7 @@ impl DapTaskConfig
InitializedReport::Ready {
metadata,
public_share: _,
peer_prep_share: Some(leader_prep_share),
peer_prep_share: leader_prep_share,
prep_share: helper_prep_share,
prep_state: helper_prep_state,
} => {
@@ -581,11 +361,6 @@ impl DapTaskConfig
}
}

InitializedReport::Ready {
peer_prep_share: None,
..
} => return Err(fatal_error!(err = "expected leader prep share, got none")),

InitializedReport::Rejected {
metadata: _,
failure,
@@ -769,21 +544,3 @@ fn produce_encrypted_agg_share(

hpke_config.encrypt(info, &agg_share_data)
}

/// checks if an iterator has no duplicate items, returns the ok if there are no dups or an error
/// with the first offending item.
fn no_duplicates<I>(iterator: I) -> Result<(), I::Item>
where
I: Iterator,
I::Item: Eq + std::hash::Hash + Copy,
{
let (lower, upper) = iterator.size_hint();
let mut seen = HashSet::with_capacity(upper.unwrap_or(lower));

for item in iterator {
if !seen.insert(item) {
return Err(item);
}
}
Ok(())
}
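
no_duplicates also leaves this file for the parent protocol module; the new use super::{ no_duplicates, ... } at the top of the hunk re-imports it here. In the removed code above it rejects reports whose extension list repeats a type code. Below is a self-contained copy of the helper with a small usage demo; the demo is illustrative, not from the commit.

use std::collections::HashSet;

/// Returns Ok(()) if the iterator yields no duplicates, otherwise Err with
/// the first repeated item (same contract as the helper above).
fn no_duplicates<I>(iterator: I) -> Result<(), I::Item>
where
    I: Iterator,
    I::Item: Eq + std::hash::Hash + Copy,
{
    let (lower, upper) = iterator.size_hint();
    let mut seen = HashSet::with_capacity(upper.unwrap_or(lower));
    for item in iterator {
        if !seen.insert(item) {
            return Err(item);
        }
    }
    Ok(())
}

fn main() {
    // e.g. extension type codes from a decoded PlaintextInputShare
    assert_eq!(no_duplicates([0x0001u16, 0x0002].into_iter()), Ok(()));
    assert_eq!(no_duplicates([0x0001u16, 0x0001].into_iter()), Err(0x0001));
}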