Skip to content

Commit

Permalink
Update libprio-rs: Prio3Sum w/ max_measurement rather than bits. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
branlwyd authored Dec 11, 2024
1 parent eabc563 commit c26487c
Show file tree
Hide file tree
Showing 21 changed files with 403 additions and 340 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pretty_assertions = "1.4.1"
# re-enable them
# TODO(#3436): switch to a released version of libprio, once there is a released version implementing VDAF-13
# prio = { version = "0.16.7", default-features = false, features = ["experimental"] }
prio = { git = "https://github.com/divviup/libprio-rs", rev = "c85e537682a7932edc0c44c80049df0429d6fa4c", default-features = false, features = ["experimental"] }
prio = { git = "https://github.com/divviup/libprio-rs", rev = "3c1aeb30c661d373566749a81589fc0a4045f89a", default-features = false, features = ["experimental"] }
prometheus = "0.13.4"
querystring = "1.1.0"
quickcheck = { version = "1.0.3", default-features = false }
Expand Down
4 changes: 2 additions & 2 deletions aggregator/src/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -828,8 +828,8 @@ impl<C: Clock> TaskAggregator<C> {
VdafOps::Prio3Count(Arc::new(vdaf), verify_key)
}

VdafInstance::Prio3Sum { bits } => {
let vdaf = Prio3::new_sum(2, *bits)?;
VdafInstance::Prio3Sum { max_measurement } => {
let vdaf = Prio3::new_sum(2, u128::from(*max_measurement))?;
let verify_key = task.vdaf_verify_key()?;
VdafOps::Prio3Sum(Arc::new(vdaf), verify_key)
}
Expand Down
8 changes: 4 additions & 4 deletions aggregator/src/aggregator/aggregation_job_creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,8 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
.await
}

(task::BatchMode::TimeInterval, VdafInstance::Prio3Sum { bits }) => {
let vdaf = Arc::new(Prio3::new_sum(2, *bits)?);
(task::BatchMode::TimeInterval, VdafInstance::Prio3Sum { max_measurement }) => {
let vdaf = Arc::new(Prio3::new_sum(2, u128::from(*max_measurement))?);
self.create_aggregation_jobs_for_time_interval_task_no_param::<VERIFY_KEY_LENGTH, Prio3Sum>(task, vdaf)
.await
}
Expand Down Expand Up @@ -415,9 +415,9 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
task::BatchMode::LeaderSelected {
batch_time_window_size,
},
VdafInstance::Prio3Sum { bits },
VdafInstance::Prio3Sum { max_measurement },
) => {
let vdaf = Arc::new(Prio3::new_sum(2, *bits)?);
let vdaf = Arc::new(Prio3::new_sum(2, u128::from(*max_measurement))?);
let batch_time_window_size = *batch_time_window_size;
self.create_aggregation_jobs_for_leader_selected_task_no_param::<
VERIFY_KEY_LENGTH,
Expand Down
4 changes: 2 additions & 2 deletions aggregator/src/aggregator/aggregation_job_writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,10 @@ where
.aggregated_report_share_dimension_histogram
.record(1, &[KeyValue::new("type", "Prio3Count")]),

Prio3Sum { bits } => metrics
Prio3Sum { max_measurement } => metrics
.aggregated_report_share_dimension_histogram
.record(
u64::try_from(*bits).unwrap_or(u64::MAX),
*max_measurement,
&[KeyValue::new("type", "Prio3Sum")],
),

Expand Down
30 changes: 20 additions & 10 deletions aggregator/src/binaries/janus_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1201,10 +1201,15 @@ mod tests {
.build()
.leader_view()
.unwrap(),
TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Sum { bits: 64 })
.build()
.helper_view()
.unwrap(),
TaskBuilder::new(
BatchMode::TimeInterval,
VdafInstance::Prio3Sum {
max_measurement: 4096,
},
)
.build()
.helper_view()
.unwrap(),
]);

let written_tasks = run_provision_tasks_testcase(&ds, &tasks, false).await;
Expand Down Expand Up @@ -1254,10 +1259,15 @@ mod tests {
.build()
.leader_view()
.unwrap(),
TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Sum { bits: 64 })
.build()
.leader_view()
.unwrap(),
TaskBuilder::new(
BatchMode::TimeInterval,
VdafInstance::Prio3Sum {
max_measurement: 4096,
},
)
.build()
.leader_view()
.unwrap(),
]);

let ephemeral_datastore = ephemeral_datastore().await;
Expand Down Expand Up @@ -1327,7 +1337,7 @@ mod tests {
- peer_aggregator_endpoint: https://helper
batch_mode: TimeInterval
vdaf: !Prio3Sum
bits: 2
max_measurement: 4096
role: Leader
vdaf_verify_key:
task_end: 9000000000
Expand All @@ -1350,7 +1360,7 @@ mod tests {
- peer_aggregator_endpoint: https://leader
batch_mode: TimeInterval
vdaf: !Prio3Sum
bits: 2
max_measurement: 4096
role: Helper
vdaf_verify_key:
task_end: 9000000000
Expand Down
14 changes: 12 additions & 2 deletions aggregator_core/src/datastore/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,18 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) {
},
Role::Helper,
),
(VdafInstance::Prio3Sum { bits: 64 }, Role::Helper),
(VdafInstance::Prio3Sum { bits: 32 }, Role::Helper),
(
VdafInstance::Prio3Sum {
max_measurement: 4096,
},
Role::Helper,
),
(
VdafInstance::Prio3Sum {
max_measurement: 4096,
},
Role::Helper,
),
(
VdafInstance::Prio3Histogram {
length: 4,
Expand Down
2 changes: 1 addition & 1 deletion collector/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1032,7 +1032,7 @@ mod tests {
async fn successful_collect_prio3_sum() {
install_test_trace_subscriber();
let mut server = mockito::Server::new_async().await;
let vdaf = Prio3::new_sum(2, 8).unwrap();
let vdaf = Prio3::new_sum(2, 255).unwrap();
let transcript = run_vdaf(&vdaf, &random(), &random(), &(), &random(), &144);
let collector = setup_collector(&mut server, vdaf);

Expand Down
26 changes: 14 additions & 12 deletions core/src/vdaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub enum VdafInstance {
/// A `Prio3` counter.
Prio3Count,
/// A `Prio3` sum.
Prio3Sum { bits: usize },
Prio3Sum { max_measurement: u64 },
/// A vector of `Prio3` sums.
Prio3SumVec {
bits: usize,
Expand Down Expand Up @@ -182,10 +182,8 @@ impl TryFrom<&taskprov::VdafConfig> for VdafInstance {
fn try_from(value: &taskprov::VdafConfig) -> Result<Self, Self::Error> {
match value {
taskprov::VdafConfig::Prio3Count => Ok(Self::Prio3Count),
taskprov::VdafConfig::Prio3Sum {
max_measurement: _max_measurement,
} => Ok(Self::Prio3Sum {
bits: 32, // TODO(#3436): plumb through max_measurement once it's available
taskprov::VdafConfig::Prio3Sum { max_measurement } => Ok(Self::Prio3Sum {
max_measurement: u64::from(*max_measurement),
}),
taskprov::VdafConfig::Prio3SumVec {
bits,
Expand Down Expand Up @@ -266,8 +264,8 @@ macro_rules! vdaf_dispatch_impl_base {
$body
}

::janus_core::vdaf::VdafInstance::Prio3Sum { bits } => {
let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *bits)?;
::janus_core::vdaf::VdafInstance::Prio3Sum { max_measurement } => {
let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *max_measurement as u128)?;
type $Vdaf = ::prio::vdaf::prio3::Prio3Sum;
const $VERIFY_KEY_LEN: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH;
type $DpStrategy = janus_core::dp::NoDifferentialPrivacy;
Expand Down Expand Up @@ -638,15 +636,17 @@ mod tests {
}],
);
assert_tokens(
&VdafInstance::Prio3Sum { bits: 64 },
&VdafInstance::Prio3Sum {
max_measurement: 4096,
},
&[
Token::StructVariant {
name: "VdafInstance",
variant: "Prio3Sum",
len: 1,
},
Token::Str("bits"),
Token::U64(64),
Token::Str("max_measurement"),
Token::U64(4096),
Token::StructVariantEnd,
],
);
Expand Down Expand Up @@ -854,9 +854,11 @@ length: 10"
serde_yaml::from_str(
"---
!Prio3Sum
bits: 12"
max_measurement: 4096"
),
Ok(VdafInstance::Prio3Sum { bits: 12 })
Ok(VdafInstance::Prio3Sum {
max_measurement: 4096
})
);
assert_matches!(
serde_yaml::from_str(
Expand Down
2 changes: 1 addition & 1 deletion docs/samples/tasks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

# The task's VDAF. Each VDAF requires its own set of parameters.
vdaf: !Prio3Sum
bits: 16
max_measurement: 4096

# The DAP role of this Janus instance in this task. Either "Leader" or
# "Helper".
Expand Down
4 changes: 2 additions & 2 deletions integration_tests/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ fn json_encode_vdaf(vdaf: &VdafInstance) -> Value {
VdafInstance::Prio3Count => json!({
"type": "Prio3Count"
}),
VdafInstance::Prio3Sum { bits } => json!({
VdafInstance::Prio3Sum { max_measurement } => json!({
"type": "Prio3Sum",
"bits": format!("{bits}"),
"max_measurement": format!("{max_measurement}"),
}),
VdafInstance::Prio3SumVec {
bits,
Expand Down
16 changes: 9 additions & 7 deletions integration_tests/tests/integration/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use prio::{
flp::gadgets::ParallelSumMultithreaded,
vdaf::{self, prio3::Prio3},
};
use rand::{random, thread_rng, Rng};
use rand::{random, seq::IteratorRandom as _, thread_rng, Rng};
use std::{iter, time::Duration as StdDuration};
use tokio::time::{self, sleep};
use url::Url;
Expand Down Expand Up @@ -395,12 +395,14 @@ pub async fn submit_measurements_and_verify_aggregate(
)
.await;
}
VdafInstance::Prio3Sum { bits } => {
let vdaf = Prio3::new_sum(2, *bits).unwrap();

let measurements = iter::repeat_with(|| (random::<u128>()) >> (128 - bits))
.take(total_measurements)
.collect::<Vec<_>>();
VdafInstance::Prio3Sum { max_measurement } => {
let max_measurement = u128::from(*max_measurement);
let vdaf = Prio3::new_sum(2, max_measurement).unwrap();

let measurements: Vec<_> =
iter::repeat_with(|| (0..=max_measurement).choose(&mut thread_rng()).unwrap())
.take(total_measurements)
.collect();
let aggregate_result = measurements.iter().sum();
let test_case = AggregationTestCase {
measurements,
Expand Down
9 changes: 7 additions & 2 deletions integration_tests/tests/integration/divviup_ts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,13 @@ async fn janus_divviup_ts_sum() {
install_test_trace_subscriber();
initialize_rustls();

run_divviup_ts_integration_test("janus_divviup_ts_sum", VdafInstance::Prio3Sum { bits: 8 })
.await;
run_divviup_ts_integration_test(
"janus_divviup_ts_sum",
VdafInstance::Prio3Sum {
max_measurement: 255,
},
)
.await;
}

#[tokio::test(flavor = "multi_thread")]
Expand Down
16 changes: 12 additions & 4 deletions integration_tests/tests/integration/in_cluster.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,8 +412,11 @@ impl InClusterJanusPair {
helper_aggregator_id,
vdaf: match task.vdaf().to_owned() {
VdafInstance::Prio3Count => Vdaf::Count,
VdafInstance::Prio3Sum { bits } => Vdaf::Sum {
bits: bits.try_into().unwrap(),
VdafInstance::Prio3Sum {
max_measurement: _max_measurement,
} => Vdaf::Sum {
// TODO(#3436): once divviup_client is updated for DAP-13, plumb max_measurement through
bits: 64,
},
VdafInstance::Prio3SumVec {
bits,
Expand Down Expand Up @@ -564,8 +567,13 @@ async fn in_cluster_sum() {
initialize_rustls();

// Start port forwards and set up task.
let janus_pair =
InClusterJanusPair::new(VdafInstance::Prio3Sum { bits: 16 }, BatchMode::TimeInterval).await;
let janus_pair = InClusterJanusPair::new(
VdafInstance::Prio3Sum {
max_measurement: 65535,
},
BatchMode::TimeInterval,
)
.await;

// Run the behavioral test.
submit_measurements_and_verify_aggregate(
Expand Down
8 changes: 6 additions & 2 deletions integration_tests/tests/integration/janus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,9 @@ async fn janus_janus_sum_16() {
// Start servers.
let janus_pair = JanusContainerPair::new(
TEST_NAME,
VdafInstance::Prio3Sum { bits: 16 },
VdafInstance::Prio3Sum {
max_measurement: 65535,
},
BatchMode::TimeInterval,
)
.await;
Expand All @@ -186,7 +188,9 @@ async fn janus_in_process_sum_16() {
// Start servers.
let janus_pair = JanusInProcessPair::new(TaskBuilder::new(
BatchMode::TimeInterval,
VdafInstance::Prio3Sum { bits: 16 },
VdafInstance::Prio3Sum {
max_measurement: 4096,
},
))
.await;

Expand Down
5 changes: 3 additions & 2 deletions interop_binaries/src/commands/janus_interop_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,10 @@ async fn handle_upload(
handle_upload_generic(http_client, vdaf, request, measurement != 0).await?;
}

VdafInstance::Prio3Sum { bits } => {
VdafInstance::Prio3Sum { max_measurement } => {
let measurement = parse_primitive_measurement::<u128>(request.measurement.clone())?;
let vdaf = Prio3::new_sum(2, bits).context("failed to construct Prio3Sum VDAF")?;
let vdaf = Prio3::new_sum(2, u128::from(max_measurement))
.context("failed to construct Prio3Sum VDAF")?;
handle_upload_generic(http_client, vdaf, request, measurement).await?;
}

Expand Down
Loading

0 comments on commit c26487c

Please sign in to comment.