Skip to content

Commit

Permalink
Refactor Prio3 dispatch logic, remove unneeded variants
Browse files Browse the repository at this point in the history
Move sharding, preparation, and unsharding dispatchers to `Prio3Config`.
Not all variants are supported by all DAP versions, so pass the version
to each dispatcher in order to resolve DAP and Prio3 version
compatibility. (Note Prio3 is not API-compatible across versions.)

The following variants are supported in DAP-13:

* Prio3Count
* Prio3Sum (to be added)
* Prio3SumVec
* Prio3Histogram

The following variants are supported in DAP-09:

* Prio3SumVecField64MultiproofHmacSha256Aes128
  • Loading branch information
cjpatton committed Dec 19, 2024
1 parent 11c1754 commit 83fde97
Show file tree
Hide file tree
Showing 14 changed files with 722 additions and 1,262 deletions.
25 changes: 14 additions & 11 deletions crates/daphne/benches/aggregation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,18 +37,20 @@ fn consume_reports_vary_vdaf_dimension(c: &mut Criterion) {
let mut test = AggregationJobTest::new(
&VdafConfig::Prio2 { dimension: 0 },
HpkeKemId::P256HkdfSha256,
DapVersion::Latest,
DapVersion::Draft09,
);
test.disable_replay_protection();

let mut g = c.benchmark_group(function!());
for vdaf_length in vdaf_lengths {
let vdaf = VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 {
bits: 1,
length: vdaf_length,
chunk_length: 320,
num_proofs: 2,
});
let vdaf = VdafConfig::Prio3(
Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 {
bits: 1,
length: vdaf_length,
chunk_length: 320,
num_proofs: 2,
},
);
test.change_vdaf(vdaf);
let reports = test
.produce_repeated_reports(vdaf.gen_measurement().unwrap())
Expand All @@ -66,15 +68,16 @@ fn consume_reports_vary_vdaf_dimension(c: &mut Criterion) {
}

fn consume_reports_vary_num_reports(c: &mut Criterion) {
const VDAF: VdafConfig =
VdafConfig::Prio3Draft09(Prio3Config::SumVecField64MultiproofHmacSha256Aes128 {
const VDAF: VdafConfig = VdafConfig::Prio3(
Prio3Config::Draft09SumVecField64MultiproofHmacSha256Aes128 {
bits: 1,
length: 1000,
chunk_length: 320,
num_proofs: 2,
});
},
);

let mut test = AggregationJobTest::new(&VDAF, HpkeKemId::P256HkdfSha256, DapVersion::Latest);
let mut test = AggregationJobTest::new(&VDAF, HpkeKemId::P256HkdfSha256, DapVersion::Draft09);
test.disable_replay_protection();

let mut g = c.benchmark_group(function!());
Expand Down
27 changes: 8 additions & 19 deletions crates/daphne/src/protocol/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ use crate::{
protocol::{decode_ping_pong_framed, PingPongMessageType},
vdaf::{
prio2::{prio2_prep_finish, prio2_prep_finish_from_shares},
prio3::{prio3_prep_finish, prio3_prep_finish_from_shares},
prio3_draft09::{prio3_draft09_prep_finish, prio3_draft09_prep_finish_from_shares},
VdafError,
},
AggregationJobReportState, DapAggregateShare, DapAggregateSpan, DapAggregationJobState,
Expand Down Expand Up @@ -286,6 +284,7 @@ impl DapTaskConfig {
report_status: &HashMap<ReportId, ReportProcessedStatus>,
part_batch_sel: &PartialBatchSelector,
initialized_reports: &[InitializedReport<WithPeerPrepShare>],
version: DapVersion,
) -> Result<(DapAggregateSpan<DapAggregateShare>, AggregationJobResp), DapError> {
let num_reports = initialized_reports.len();
let mut agg_span = DapAggregateSpan::default();
Expand All @@ -304,23 +303,15 @@ impl DapTaskConfig {
prep_state: helper_prep_state,
} => {
let res = match &self.vdaf {
VdafConfig::Prio3Draft09(prio3_config) => {
prio3_draft09_prep_finish_from_shares(
prio3_config,
VdafConfig::Prio3(prio3_config) => prio3_config
.prep_finish_from_shares(
version,
1,
task_id,
helper_prep_state.clone(),
helper_prep_share.clone(),
leader_prep_share,
)
}
VdafConfig::Prio3(prio3_config) => prio3_prep_finish_from_shares(
prio3_config,
1,
task_id,
helper_prep_state.clone(),
helper_prep_share.clone(),
leader_prep_share,
),
),
VdafConfig::Prio2 { dimension } => prio2_prep_finish_from_shares(
*dimension,
helper_prep_state.clone(),
Expand Down Expand Up @@ -406,6 +397,7 @@ impl DapTaskConfig {
state: DapAggregationJobState,
agg_job_resp: AggregationJobResp,
metrics: &dyn DaphneMetrics,
version: DapVersion,
) -> Result<DapAggregateSpan<DapAggregateShare>, DapError> {
if agg_job_resp.transitions.len() != state.seq.len() {
return Err(DapAbort::InvalidMessage {
Expand Down Expand Up @@ -459,11 +451,8 @@ impl DapTaskConfig {
};

let res = match &self.vdaf {
VdafConfig::Prio3Draft09(prio3_config) => {
prio3_draft09_prep_finish(prio3_config, leader.prep_state, prep_msg)
}
VdafConfig::Prio3(prio3_config) => {
prio3_prep_finish(prio3_config, leader.prep_state, prep_msg, *task_id)
prio3_config.prep_finish(leader.prep_state, prep_msg, *task_id, version)
}
VdafConfig::Prio2 { dimension } => {
prio2_prep_finish(*dimension, leader.prep_state, prep_msg)
Expand Down
10 changes: 4 additions & 6 deletions crates/daphne/src/protocol/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
constants::DapAggregatorRole,
hpke::{info_and_aad, HpkeConfig},
messages::{Extension, PlaintextInputShare, Report, ReportId, ReportMetadata, TaskId, Time},
vdaf::{prio2::prio2_shard, prio3::prio3_shard, prio3_draft09::prio3_draft09_shard, VdafError},
vdaf::{prio2::prio2_shard, VdafError},
DapError, DapMeasurement, DapVersion, VdafConfig,
};
use prio::codec::ParameterizedEncode;
Expand Down Expand Up @@ -45,7 +45,7 @@ impl VdafConfig {
let mut rng = thread_rng();
let report_id = ReportId(rng.gen());
let (public_share, input_shares) = self
.produce_input_shares(measurement, &report_id.0, task_id)
.produce_input_shares(measurement, &report_id.0, task_id, version)
.map_err(DapError::from_vdaf)?;
Self::produce_report_with_extensions_for_shares(
public_share,
Expand Down Expand Up @@ -122,13 +122,11 @@ impl VdafConfig {
measurement: DapMeasurement,
nonce: &[u8; 16],
task_id: &TaskId,
version: DapVersion,
) -> Result<(Vec<u8>, [Vec<u8>; 2]), VdafError> {
match self {
Self::Prio3Draft09(prio3_config) => {
Ok(prio3_draft09_shard(prio3_config, measurement, nonce)?)
}
Self::Prio3(prio3_config) => {
Ok(prio3_shard(prio3_config, measurement, nonce, *task_id)?)
Ok(prio3_config.shard(version, measurement, nonce, *task_id)?)
}
Self::Prio2 { dimension } => Ok(prio2_shard(*dimension, measurement, nonce)?),
#[cfg(feature = "experimental")]
Expand Down
7 changes: 3 additions & 4 deletions crates/daphne/src/protocol/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
fatal_error,
hpke::{info_and_aad, HpkeDecrypter},
messages::{BatchSelector, HpkeCiphertext, TaskId},
vdaf::{prio2::prio2_unshard, prio3::prio3_unshard, prio3_draft09::prio3_draft09_unshard},
vdaf::prio2::prio2_unshard,
DapAggregateResult, DapAggregationParam, DapError, DapVersion, VdafConfig,
};

Expand Down Expand Up @@ -73,10 +73,9 @@ impl VdafConfig {

let num_measurements = usize::try_from(report_count).unwrap();
match self {
Self::Prio3Draft09(prio3_config) => {
prio3_draft09_unshard(prio3_config, num_measurements, agg_shares)
Self::Prio3(prio3_config) => {
prio3_config.unshard(version, num_measurements, agg_shares)
}
Self::Prio3(prio3_config) => prio3_unshard(prio3_config, num_measurements, agg_shares),
Self::Prio2 { dimension } => prio2_unshard(*dimension, num_measurements, agg_shares),
#[cfg(feature = "experimental")]
Self::Mastic {
Expand Down
Loading

0 comments on commit 83fde97

Please sign in to comment.