Skip to content

Commit

Permalink
Unify VDAF dispatching logic
Browse files Browse the repository at this point in the history
Add dispatchers for sharding, preparation, and unsharding to
`VdafConfig`. Pass the DAP version to the dispatcher so that we can
properly resolve supported VDAFs.

Remove support for the following unspecified combinations:

* DAP-13/Pine32HmacSha256Aes128
* DAP-13/Pine64HmacSha256Aes128
* DAP-09/Mastic

Technically, since Mastic is currently a shim and not a functional
VDAF, we could support it in DAP-09 as well. When the implementation
is complete
(divviup/libprio-rs#947), it will only be
VDAF-13-compatible and thus only supported in DAP-13.
  • Loading branch information
cjpatton committed Dec 20, 2024
1 parent 4ee168f commit 5ac2daf
Show file tree
Hide file tree
Showing 12 changed files with 249 additions and 179 deletions.
1 change: 0 additions & 1 deletion crates/dapf/src/acceptance/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,6 @@ impl Test {
agg_job_state,
agg_job_resp,
self.metrics(),
task_config.version,
)?;

let aggregated_report_count = agg_share_span
Expand Down
66 changes: 13 additions & 53 deletions crates/daphne/src/protocol/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ use super::{
no_duplicates,
report_init::{InitializedReport, WithPeerPrepShare},
};
#[cfg(feature = "experimental")]
use crate::vdaf::mastic::{mastic_prep_finish, mastic_prep_finish_from_shares};
use crate::{
constants::DapAggregatorRole,
error::DapAbort,
Expand All @@ -19,12 +17,9 @@ use crate::{
},
metrics::{DaphneMetrics, ReportStatus},
protocol::{decode_ping_pong_framed, PingPongMessageType},
vdaf::{
prio2::{prio2_prep_finish, prio2_prep_finish_from_shares},
VdafError,
},
vdaf::VdafError,
AggregationJobReportState, DapAggregateShare, DapAggregateSpan, DapAggregationJobState,
DapAggregationParam, DapError, DapTaskConfig, DapVersion, VdafConfig,
DapAggregationParam, DapError, DapTaskConfig, DapVersion,
};

use prio::codec::{encode_u32_items, Encode, ParameterizedDecode, ParameterizedEncode};
Expand Down Expand Up @@ -284,7 +279,6 @@ impl DapTaskConfig {
report_status: &HashMap<ReportId, ReportProcessedStatus>,
part_batch_sel: &PartialBatchSelector,
initialized_reports: &[InitializedReport<WithPeerPrepShare>],
version: DapVersion,
) -> Result<(DapAggregateSpan<DapAggregateShare>, AggregationJobResp), DapError> {
let num_reports = initialized_reports.len();
let mut agg_span = DapAggregateSpan::default();
Expand All @@ -302,39 +296,14 @@ impl DapTaskConfig {
prep_share: helper_prep_share,
prep_state: helper_prep_state,
} => {
let res = match &self.vdaf {
VdafConfig::Prio3(prio3_config) => prio3_config
.prep_finish_from_shares(
version,
1,
task_id,
helper_prep_state.clone(),
helper_prep_share.clone(),
leader_prep_share,
),
VdafConfig::Prio2 { dimension } => prio2_prep_finish_from_shares(
*dimension,
helper_prep_state.clone(),
helper_prep_share.clone(),
leader_prep_share,
),
#[cfg(feature = "experimental")]
VdafConfig::Mastic {
input_size: _,
weight_config,
} => mastic_prep_finish_from_shares(
*weight_config,
helper_prep_state.clone(),
helper_prep_share.clone(),
leader_prep_share,
),
VdafConfig::Pine(pine) => pine.prep_finish_from_shares(
1,
helper_prep_state.clone(),
helper_prep_share.clone(),
leader_prep_share,
),
};
let res = self.vdaf.prep_finish_from_shares(
self.version,
1,
task_id,
helper_prep_state.clone(),
helper_prep_share.clone(),
leader_prep_share,
);

match res {
Ok((data, prep_msg)) => {
Expand Down Expand Up @@ -397,7 +366,6 @@ impl DapTaskConfig {
state: DapAggregationJobState,
agg_job_resp: AggregationJobResp,
metrics: &dyn DaphneMetrics,
version: DapVersion,
) -> Result<DapAggregateSpan<DapAggregateShare>, DapError> {
if agg_job_resp.transitions.len() != state.seq.len() {
return Err(DapAbort::InvalidMessage {
Expand Down Expand Up @@ -450,17 +418,9 @@ impl DapTaskConfig {
}
};

let res = match &self.vdaf {
VdafConfig::Prio3(prio3_config) => {
prio3_config.prep_finish(leader.prep_state, prep_msg, *task_id, version)
}
VdafConfig::Prio2 { dimension } => {
prio2_prep_finish(*dimension, leader.prep_state, prep_msg)
}
#[cfg(feature = "experimental")]
VdafConfig::Mastic { .. } => mastic_prep_finish(leader.prep_state, prep_msg),
VdafConfig::Pine(pine) => pine.prep_finish(leader.prep_state, prep_msg),
};
let res = self
.vdaf
.prep_finish(leader.prep_state, prep_msg, *task_id, self.version);

match res {
Ok(data) => {
Expand Down
27 changes: 1 addition & 26 deletions crates/daphne/src/protocol/client.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#[cfg(feature = "experimental")]
use crate::vdaf::mastic::mastic_shard;
use crate::{
constants::DapAggregatorRole,
hpke::{info_and_aad, HpkeConfig},
messages::{Extension, PlaintextInputShare, Report, ReportId, ReportMetadata, TaskId, Time},
vdaf::{prio2::prio2_shard, VdafError},
DapError, DapMeasurement, DapVersion, VdafConfig,
};
use prio::codec::ParameterizedEncode;
Expand Down Expand Up @@ -45,7 +42,7 @@ impl VdafConfig {
let mut rng = thread_rng();
let report_id = ReportId(rng.gen());
let (public_share, input_shares) = self
.produce_input_shares(measurement, &report_id.0, task_id, version)
.shard(measurement, &report_id.0, *task_id, version)
.map_err(DapError::from_vdaf)?;
Self::produce_report_with_extensions_for_shares(
public_share,
Expand Down Expand Up @@ -116,28 +113,6 @@ impl VdafConfig {
})
}

/// Generate shares for a measurement.
pub(crate) fn produce_input_shares(
&self,
measurement: DapMeasurement,
nonce: &[u8; 16],
task_id: &TaskId,
version: DapVersion,
) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> {
match self {
Self::Prio3(prio3_config) => {
Ok(prio3_config.shard(version, measurement, nonce, *task_id)?)
}
Self::Prio2 { dimension } => Ok(prio2_shard(*dimension, measurement, nonce)?),
#[cfg(feature = "experimental")]
VdafConfig::Mastic {
input_size,
weight_config,
} => Ok(mastic_shard(*input_size, *weight_config, measurement)?),
VdafConfig::Pine(pine) => Ok(pine.shard(measurement, nonce)?),
}
}

/// Generate a report for a measurement. This method is run by the Client.
///
/// # Inputs
Expand Down
18 changes: 2 additions & 16 deletions crates/daphne/src/protocol/collector.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#[cfg(feature = "experimental")]
use crate::vdaf::mastic::mastic_unshard;
use crate::{
constants::DapAggregatorRole,
fatal_error,
hpke::{info_and_aad, HpkeDecrypter},
messages::{BatchSelector, HpkeCiphertext, TaskId},
vdaf::prio2::prio2_unshard,
DapAggregateResult, DapAggregationParam, DapError, DapVersion, VdafConfig,
};

Expand Down Expand Up @@ -72,18 +69,7 @@ impl VdafConfig {
}

let num_measurements = usize::try_from(report_count).unwrap();
match self {
Self::Prio3(prio3_config) => {
prio3_config.unshard(version, num_measurements, agg_shares)
}
Self::Prio2 { dimension } => prio2_unshard(*dimension, num_measurements, agg_shares),
#[cfg(feature = "experimental")]
Self::Mastic {
input_size: _,
weight_config,
} => mastic_unshard(*weight_config, agg_param, agg_shares),
Self::Pine(pine) => pine.unshard(num_measurements, agg_shares),
}
.map_err(|e| fatal_error!(err = ?e, "failed to unshard agg_shares"))
self.unshard(version, agg_param, num_measurements, agg_shares)
.map_err(|e| fatal_error!(err = ?e, "failed to unshard agg_shares"))
}
}
12 changes: 6 additions & 6 deletions crates/daphne/src/protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,10 +841,10 @@ mod test {
let (invalid_public_share, mut invalid_input_shares) = self
.task_config
.vdaf
.produce_input_shares(
.shard(
measurement,
&report_id.0,
&self.task_id,
self.task_id,
self.task_config.version,
)
.unwrap();
Expand Down Expand Up @@ -872,10 +872,10 @@ mod test {
let (mut invalid_public_share, invalid_input_shares) = self
.task_config
.vdaf
.produce_input_shares(
.shard(
measurement,
&report_id.0,
&self.task_id,
self.task_id,
self.task_config.version,
)
.unwrap();
Expand Down Expand Up @@ -903,10 +903,10 @@ mod test {
let (invalid_public_share, mut invalid_input_shares) = self
.task_config
.vdaf
.produce_input_shares(
.shard(
measurement,
&report_id.0,
&self.task_id,
self.task_id,
self.task_config.version,
)
.unwrap();
Expand Down
55 changes: 11 additions & 44 deletions crates/daphne/src/protocol/report_init.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

#[cfg(feature = "experimental")]
use crate::vdaf::mastic::mastic_prep_init;
use crate::{
constants::DapAggregatorRole,
hpke::{info_and_aad, HpkeDecrypter},
messages::{
self, Extension, PlaintextInputShare, ReportError, ReportMetadata, ReportShare, TaskId,
},
protocol::{decode_ping_pong_framed, no_duplicates, PingPongMessageType},
vdaf::{prio2::prio2_prep_init, VdafConfig, VdafPrepShare, VdafPrepState},
vdaf::{VdafPrepShare, VdafPrepState},
DapAggregationParam, DapError, DapTaskConfig,
};
use prio::codec::{CodecError, ParameterizedDecode as _};
Expand Down Expand Up @@ -98,9 +96,6 @@ impl<P> InitializedReport<P> {
task_config: &DapTaskConfig,
report_share: ReportShare,
prep_init_payload: S,
// We need to use this variable for Mastic, which is currently fenced by the
// "experimental" feature.
#[cfg_attr(not(feature = "experimental"), expect(unused_variables))]
agg_param: &DapAggregationParam,
) -> Result<Self, DapError>
where
Expand Down Expand Up @@ -193,44 +188,16 @@ impl<P> InitializedReport<P> {
DapAggregatorRole::Leader => 0,
DapAggregatorRole::Helper => 1,
};
let res = match &task_config.vdaf {
VdafConfig::Prio3(prio3_config) => prio3_config.prep_init(
task_config.version,
&task_config.vdaf_verify_key,
*task_id,
agg_id,
&report_share.report_metadata.id.0,
&report_share.public_share,
&input_share,
),
VdafConfig::Prio2 { dimension } => prio2_prep_init(
*dimension,
&task_config.vdaf_verify_key,
agg_id,
&report_share.report_metadata.id.0,
&report_share.public_share,
&input_share,
),
#[cfg(feature = "experimental")]
VdafConfig::Mastic {
input_size,
weight_config,
} => mastic_prep_init(
*input_size,
*weight_config,
&task_config.vdaf_verify_key,
agg_param,
&report_share.public_share,
input_share.as_ref(),
),
VdafConfig::Pine(pine) => pine.prep_init(
&task_config.vdaf_verify_key,
agg_id,
&report_share.report_metadata.id.0,
&report_share.public_share,
&input_share,
),
};
let res = task_config.vdaf.prep_init(
task_config.version,
&task_config.vdaf_verify_key,
*task_id,
agg_id,
agg_param,
report_share.report_metadata.id.as_ref(),
&report_share.public_share,
&input_share,
);

match res {
Ok((prep_state, prep_share)) => Ok(InitializedReport::Ready {
Expand Down
1 change: 0 additions & 1 deletion crates/daphne/src/roles/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ async fn finish_agg_job_and_aggregate(
&report_status,
part_batch_sel,
initialized_reports,
task_config.version,
)?;

let put_shares_result = helper
Expand Down
9 changes: 2 additions & 7 deletions crates/daphne/src/roles/leader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,13 +341,8 @@ async fn run_agg_job<A: DapLeader>(
.map_err(|e| DapAbort::from_codec_error(e, *task_id))?;

// Handle AggregationJobResp.
let agg_span = task_config.consume_agg_job_resp(
task_id,
agg_job_state,
agg_job_resp,
metrics,
task_config.version,
)?;
let agg_span =
task_config.consume_agg_job_resp(task_id, agg_job_state, agg_job_resp, metrics)?;

let out_shares_count = agg_span.report_count() as u64;
if out_shares_count == 0 {
Expand Down
10 changes: 1 addition & 9 deletions crates/daphne/src/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ impl AggregationJobTest {
self.replay_protection,
)
.unwrap(),
self.task_config.version,
)
.unwrap()
}
Expand All @@ -242,7 +241,6 @@ impl AggregationJobTest {
leader_state,
agg_job_resp,
&self.leader_metrics,
self.task_config.version,
)
.unwrap()
}
Expand All @@ -255,13 +253,7 @@ impl AggregationJobTest {
) -> DapError {
let metrics = &self.leader_metrics;
self.task_config
.consume_agg_job_resp(
&self.task_id,
leader_state,
agg_job_resp,
metrics,
self.task_config.version,
)
.consume_agg_job_resp(&self.task_id, leader_state, agg_job_resp, metrics)
.expect_err("consume_agg_job_resp() succeeded; expected failure")
}

Expand Down
Loading

0 comments on commit 5ac2daf

Please sign in to comment.