From c26487c6cdd800001612cd8e98e238116a7cb700 Mon Sep 17 00:00:00 2001 From: Brandon Pitman Date: Wed, 11 Dec 2024 08:36:22 -0800 Subject: [PATCH] Update libprio-rs: Prio3Sum w/ `max_measurement` rather than `bits`. (#3565) --- Cargo.lock | 2 +- Cargo.toml | 2 +- aggregator/src/aggregator.rs | 4 +- .../src/aggregator/aggregation_job_creator.rs | 8 +- .../src/aggregator/aggregation_job_writer.rs | 4 +- aggregator/src/binaries/janus_cli.rs | 30 +- aggregator_core/src/datastore/tests.rs | 14 +- collector/src/lib.rs | 2 +- core/src/vdaf.rs | 26 +- docs/samples/tasks.yaml | 2 +- integration_tests/src/client.rs | 4 +- integration_tests/tests/integration/common.rs | 16 +- .../tests/integration/divviup_ts.rs | 9 +- .../tests/integration/in_cluster.rs | 16 +- integration_tests/tests/integration/janus.rs | 8 +- .../src/commands/janus_interop_client.rs | 5 +- .../src/commands/janus_interop_collector.rs | 543 +++++++++--------- interop_binaries/src/lib.rs | 10 +- interop_binaries/tests/end_to_end.rs | 2 +- tools/src/bin/collect.rs | 31 +- tools/tests/cmd/collect.trycmd | 5 +- 21 files changed, 403 insertions(+), 340 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5480ade20..ce6395c9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4252,7 +4252,7 @@ dependencies = [ [[package]] name = "prio" version = "0.17.0-alpha.0" -source = "git+https://github.com/divviup/libprio-rs?rev=c85e537682a7932edc0c44c80049df0429d6fa4c#c85e537682a7932edc0c44c80049df0429d6fa4c" +source = "git+https://github.com/divviup/libprio-rs?rev=3c1aeb30c661d373566749a81589fc0a4045f89a#3c1aeb30c661d373566749a81589fc0a4045f89a" dependencies = [ "aes 0.8.4", "bitvec", diff --git a/Cargo.toml b/Cargo.toml index ba0506d0b..5c85be6a3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,7 +77,7 @@ pretty_assertions = "1.4.1" # re-enable them # TODO(#3436): switch to a released version of libprio, once there is a released version implementing VDAF-13 # prio = { version = "0.16.7", default-features = false, features = ["experimental"] } -prio = { git = "https://github.com/divviup/libprio-rs", rev = "c85e537682a7932edc0c44c80049df0429d6fa4c", default-features = false, features = ["experimental"] } +prio = { git = "https://github.com/divviup/libprio-rs", rev = "3c1aeb30c661d373566749a81589fc0a4045f89a", default-features = false, features = ["experimental"] } prometheus = "0.13.4" querystring = "1.1.0" quickcheck = { version = "1.0.3", default-features = false } diff --git a/aggregator/src/aggregator.rs b/aggregator/src/aggregator.rs index 2f671194e..5ea1d1ffc 100644 --- a/aggregator/src/aggregator.rs +++ b/aggregator/src/aggregator.rs @@ -828,8 +828,8 @@ impl TaskAggregator { VdafOps::Prio3Count(Arc::new(vdaf), verify_key) } - VdafInstance::Prio3Sum { bits } => { - let vdaf = Prio3::new_sum(2, *bits)?; + VdafInstance::Prio3Sum { max_measurement } => { + let vdaf = Prio3::new_sum(2, u128::from(*max_measurement))?; let verify_key = task.vdaf_verify_key()?; VdafOps::Prio3Sum(Arc::new(vdaf), verify_key) } diff --git a/aggregator/src/aggregator/aggregation_job_creator.rs b/aggregator/src/aggregator/aggregation_job_creator.rs index da5cc97b2..b504251f1 100644 --- a/aggregator/src/aggregator/aggregation_job_creator.rs +++ b/aggregator/src/aggregator/aggregation_job_creator.rs @@ -311,8 +311,8 @@ impl AggregationJobCreator { .await } - (task::BatchMode::TimeInterval, VdafInstance::Prio3Sum { bits }) => { - let vdaf = Arc::new(Prio3::new_sum(2, *bits)?); + (task::BatchMode::TimeInterval, VdafInstance::Prio3Sum { max_measurement }) => { + let vdaf = Arc::new(Prio3::new_sum(2, u128::from(*max_measurement))?); self.create_aggregation_jobs_for_time_interval_task_no_param::(task, vdaf) .await } @@ -415,9 +415,9 @@ impl AggregationJobCreator { task::BatchMode::LeaderSelected { batch_time_window_size, }, - VdafInstance::Prio3Sum { bits }, + VdafInstance::Prio3Sum { max_measurement }, ) => { - let vdaf = Arc::new(Prio3::new_sum(2, *bits)?); + let vdaf = Arc::new(Prio3::new_sum(2, u128::from(*max_measurement))?); let batch_time_window_size = *batch_time_window_size; self.create_aggregation_jobs_for_leader_selected_task_no_param::< VERIFY_KEY_LENGTH, diff --git a/aggregator/src/aggregator/aggregation_job_writer.rs b/aggregator/src/aggregator/aggregation_job_writer.rs index f53198cfc..0c0201542 100644 --- a/aggregator/src/aggregator/aggregation_job_writer.rs +++ b/aggregator/src/aggregator/aggregation_job_writer.rs @@ -705,10 +705,10 @@ where .aggregated_report_share_dimension_histogram .record(1, &[KeyValue::new("type", "Prio3Count")]), - Prio3Sum { bits } => metrics + Prio3Sum { max_measurement } => metrics .aggregated_report_share_dimension_histogram .record( - u64::try_from(*bits).unwrap_or(u64::MAX), + *max_measurement, &[KeyValue::new("type", "Prio3Sum")], ), diff --git a/aggregator/src/binaries/janus_cli.rs b/aggregator/src/binaries/janus_cli.rs index ac6a9df94..d4bccc9fd 100644 --- a/aggregator/src/binaries/janus_cli.rs +++ b/aggregator/src/binaries/janus_cli.rs @@ -1201,10 +1201,15 @@ mod tests { .build() .leader_view() .unwrap(), - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Sum { bits: 64 }) - .build() - .helper_view() - .unwrap(), + TaskBuilder::new( + BatchMode::TimeInterval, + VdafInstance::Prio3Sum { + max_measurement: 4096, + }, + ) + .build() + .helper_view() + .unwrap(), ]); let written_tasks = run_provision_tasks_testcase(&ds, &tasks, false).await; @@ -1254,10 +1259,15 @@ mod tests { .build() .leader_view() .unwrap(), - TaskBuilder::new(BatchMode::TimeInterval, VdafInstance::Prio3Sum { bits: 64 }) - .build() - .leader_view() - .unwrap(), + TaskBuilder::new( + BatchMode::TimeInterval, + VdafInstance::Prio3Sum { + max_measurement: 4096, + }, + ) + .build() + .leader_view() + .unwrap(), ]); let ephemeral_datastore = ephemeral_datastore().await; @@ -1327,7 +1337,7 @@ mod tests { - peer_aggregator_endpoint: https://helper batch_mode: TimeInterval vdaf: !Prio3Sum - bits: 2 + max_measurement: 4096 role: Leader vdaf_verify_key: task_end: 9000000000 @@ -1350,7 +1360,7 @@ mod tests { - peer_aggregator_endpoint: https://leader batch_mode: TimeInterval vdaf: !Prio3Sum - bits: 2 + max_measurement: 4096 role: Helper vdaf_verify_key: task_end: 9000000000 diff --git a/aggregator_core/src/datastore/tests.rs b/aggregator_core/src/datastore/tests.rs index b4333a5bb..3f9a1a2f5 100644 --- a/aggregator_core/src/datastore/tests.rs +++ b/aggregator_core/src/datastore/tests.rs @@ -176,8 +176,18 @@ async fn roundtrip_task(ephemeral_datastore: EphemeralDatastore) { }, Role::Helper, ), - (VdafInstance::Prio3Sum { bits: 64 }, Role::Helper), - (VdafInstance::Prio3Sum { bits: 32 }, Role::Helper), + ( + VdafInstance::Prio3Sum { + max_measurement: 4096, + }, + Role::Helper, + ), + ( + VdafInstance::Prio3Sum { + max_measurement: 4096, + }, + Role::Helper, + ), ( VdafInstance::Prio3Histogram { length: 4, diff --git a/collector/src/lib.rs b/collector/src/lib.rs index c6606823d..f343fc97a 100644 --- a/collector/src/lib.rs +++ b/collector/src/lib.rs @@ -1032,7 +1032,7 @@ mod tests { async fn successful_collect_prio3_sum() { install_test_trace_subscriber(); let mut server = mockito::Server::new_async().await; - let vdaf = Prio3::new_sum(2, 8).unwrap(); + let vdaf = Prio3::new_sum(2, 255).unwrap(); let transcript = run_vdaf(&vdaf, &random(), &random(), &(), &random(), &144); let collector = setup_collector(&mut server, vdaf); diff --git a/core/src/vdaf.rs b/core/src/vdaf.rs index 346472e34..34ea533bb 100644 --- a/core/src/vdaf.rs +++ b/core/src/vdaf.rs @@ -109,7 +109,7 @@ pub enum VdafInstance { /// A `Prio3` counter. Prio3Count, /// A `Prio3` sum. - Prio3Sum { bits: usize }, + Prio3Sum { max_measurement: u64 }, /// A vector of `Prio3` sums. Prio3SumVec { bits: usize, @@ -182,10 +182,8 @@ impl TryFrom<&taskprov::VdafConfig> for VdafInstance { fn try_from(value: &taskprov::VdafConfig) -> Result { match value { taskprov::VdafConfig::Prio3Count => Ok(Self::Prio3Count), - taskprov::VdafConfig::Prio3Sum { - max_measurement: _max_measurement, - } => Ok(Self::Prio3Sum { - bits: 32, // TODO(#3436): plumb through max_measurement once it's available + taskprov::VdafConfig::Prio3Sum { max_measurement } => Ok(Self::Prio3Sum { + max_measurement: u64::from(*max_measurement), }), taskprov::VdafConfig::Prio3SumVec { bits, @@ -266,8 +264,8 @@ macro_rules! vdaf_dispatch_impl_base { $body } - ::janus_core::vdaf::VdafInstance::Prio3Sum { bits } => { - let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *bits)?; + ::janus_core::vdaf::VdafInstance::Prio3Sum { max_measurement } => { + let $vdaf = ::prio::vdaf::prio3::Prio3::new_sum(2, *max_measurement as u128)?; type $Vdaf = ::prio::vdaf::prio3::Prio3Sum; const $VERIFY_KEY_LEN: usize = ::janus_core::vdaf::VERIFY_KEY_LENGTH; type $DpStrategy = janus_core::dp::NoDifferentialPrivacy; @@ -638,15 +636,17 @@ mod tests { }], ); assert_tokens( - &VdafInstance::Prio3Sum { bits: 64 }, + &VdafInstance::Prio3Sum { + max_measurement: 4096, + }, &[ Token::StructVariant { name: "VdafInstance", variant: "Prio3Sum", len: 1, }, - Token::Str("bits"), - Token::U64(64), + Token::Str("max_measurement"), + Token::U64(4096), Token::StructVariantEnd, ], ); @@ -854,9 +854,11 @@ length: 10" serde_yaml::from_str( "--- !Prio3Sum -bits: 12" +max_measurement: 4096" ), - Ok(VdafInstance::Prio3Sum { bits: 12 }) + Ok(VdafInstance::Prio3Sum { + max_measurement: 4096 + }) ); assert_matches!( serde_yaml::from_str( diff --git a/docs/samples/tasks.yaml b/docs/samples/tasks.yaml index d88462f98..33985a4f3 100644 --- a/docs/samples/tasks.yaml +++ b/docs/samples/tasks.yaml @@ -14,7 +14,7 @@ # The task's VDAF. Each VDAF requires its own set of parameters. vdaf: !Prio3Sum - bits: 16 + max_measurement: 4096 # The DAP role of this Janus instance in this task. Either "Leader" or # "Helper". diff --git a/integration_tests/src/client.rs b/integration_tests/src/client.rs index 660374de0..524a4f7fb 100644 --- a/integration_tests/src/client.rs +++ b/integration_tests/src/client.rs @@ -88,9 +88,9 @@ fn json_encode_vdaf(vdaf: &VdafInstance) -> Value { VdafInstance::Prio3Count => json!({ "type": "Prio3Count" }), - VdafInstance::Prio3Sum { bits } => json!({ + VdafInstance::Prio3Sum { max_measurement } => json!({ "type": "Prio3Sum", - "bits": format!("{bits}"), + "max_measurement": format!("{max_measurement}"), }), VdafInstance::Prio3SumVec { bits, diff --git a/integration_tests/tests/integration/common.rs b/integration_tests/tests/integration/common.rs index 41b1bdcbe..4e4f22aa4 100644 --- a/integration_tests/tests/integration/common.rs +++ b/integration_tests/tests/integration/common.rs @@ -20,7 +20,7 @@ use prio::{ flp::gadgets::ParallelSumMultithreaded, vdaf::{self, prio3::Prio3}, }; -use rand::{random, thread_rng, Rng}; +use rand::{random, seq::IteratorRandom as _, thread_rng, Rng}; use std::{iter, time::Duration as StdDuration}; use tokio::time::{self, sleep}; use url::Url; @@ -395,12 +395,14 @@ pub async fn submit_measurements_and_verify_aggregate( ) .await; } - VdafInstance::Prio3Sum { bits } => { - let vdaf = Prio3::new_sum(2, *bits).unwrap(); - - let measurements = iter::repeat_with(|| (random::()) >> (128 - bits)) - .take(total_measurements) - .collect::>(); + VdafInstance::Prio3Sum { max_measurement } => { + let max_measurement = u128::from(*max_measurement); + let vdaf = Prio3::new_sum(2, max_measurement).unwrap(); + + let measurements: Vec<_> = + iter::repeat_with(|| (0..=max_measurement).choose(&mut thread_rng()).unwrap()) + .take(total_measurements) + .collect(); let aggregate_result = measurements.iter().sum(); let test_case = AggregationTestCase { measurements, diff --git a/integration_tests/tests/integration/divviup_ts.rs b/integration_tests/tests/integration/divviup_ts.rs index 8c3bcae2b..89f477cd3 100644 --- a/integration_tests/tests/integration/divviup_ts.rs +++ b/integration_tests/tests/integration/divviup_ts.rs @@ -58,8 +58,13 @@ async fn janus_divviup_ts_sum() { install_test_trace_subscriber(); initialize_rustls(); - run_divviup_ts_integration_test("janus_divviup_ts_sum", VdafInstance::Prio3Sum { bits: 8 }) - .await; + run_divviup_ts_integration_test( + "janus_divviup_ts_sum", + VdafInstance::Prio3Sum { + max_measurement: 255, + }, + ) + .await; } #[tokio::test(flavor = "multi_thread")] diff --git a/integration_tests/tests/integration/in_cluster.rs b/integration_tests/tests/integration/in_cluster.rs index e1bdbb3e5..1e0d2148f 100644 --- a/integration_tests/tests/integration/in_cluster.rs +++ b/integration_tests/tests/integration/in_cluster.rs @@ -412,8 +412,11 @@ impl InClusterJanusPair { helper_aggregator_id, vdaf: match task.vdaf().to_owned() { VdafInstance::Prio3Count => Vdaf::Count, - VdafInstance::Prio3Sum { bits } => Vdaf::Sum { - bits: bits.try_into().unwrap(), + VdafInstance::Prio3Sum { + max_measurement: _max_measurement, + } => Vdaf::Sum { + // TODO(#3436): once divviup_client is updated for DAP-13, plumb max_measurement through + bits: 64, }, VdafInstance::Prio3SumVec { bits, @@ -564,8 +567,13 @@ async fn in_cluster_sum() { initialize_rustls(); // Start port forwards and set up task. - let janus_pair = - InClusterJanusPair::new(VdafInstance::Prio3Sum { bits: 16 }, BatchMode::TimeInterval).await; + let janus_pair = InClusterJanusPair::new( + VdafInstance::Prio3Sum { + max_measurement: 65535, + }, + BatchMode::TimeInterval, + ) + .await; // Run the behavioral test. submit_measurements_and_verify_aggregate( diff --git a/integration_tests/tests/integration/janus.rs b/integration_tests/tests/integration/janus.rs index 8ef5188cb..353b6ad30 100644 --- a/integration_tests/tests/integration/janus.rs +++ b/integration_tests/tests/integration/janus.rs @@ -162,7 +162,9 @@ async fn janus_janus_sum_16() { // Start servers. let janus_pair = JanusContainerPair::new( TEST_NAME, - VdafInstance::Prio3Sum { bits: 16 }, + VdafInstance::Prio3Sum { + max_measurement: 65535, + }, BatchMode::TimeInterval, ) .await; @@ -186,7 +188,9 @@ async fn janus_in_process_sum_16() { // Start servers. let janus_pair = JanusInProcessPair::new(TaskBuilder::new( BatchMode::TimeInterval, - VdafInstance::Prio3Sum { bits: 16 }, + VdafInstance::Prio3Sum { + max_measurement: 4096, + }, )) .await; diff --git a/interop_binaries/src/commands/janus_interop_client.rs b/interop_binaries/src/commands/janus_interop_client.rs index 5a1d29399..97f061a8d 100644 --- a/interop_binaries/src/commands/janus_interop_client.rs +++ b/interop_binaries/src/commands/janus_interop_client.rs @@ -123,9 +123,10 @@ async fn handle_upload( handle_upload_generic(http_client, vdaf, request, measurement != 0).await?; } - VdafInstance::Prio3Sum { bits } => { + VdafInstance::Prio3Sum { max_measurement } => { let measurement = parse_primitive_measurement::(request.measurement.clone())?; - let vdaf = Prio3::new_sum(2, bits).context("failed to construct Prio3Sum VDAF")?; + let vdaf = Prio3::new_sum(2, u128::from(max_measurement)) + .context("failed to construct Prio3Sum VDAF")?; handle_upload_generic(http_client, vdaf, request, measurement).await?; } diff --git a/interop_binaries/src/commands/janus_interop_collector.rs b/interop_binaries/src/commands/janus_interop_collector.rs index 941501a81..a7f913ddd 100644 --- a/interop_binaries/src/commands/janus_interop_collector.rs +++ b/interop_binaries/src/commands/janus_interop_collector.rs @@ -289,76 +289,137 @@ async fn handle_collection_start( }; let vdaf_instance = task_state.vdaf.clone().into(); - let task_handle = - match (query, vdaf_instance) { - (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Count {}) => { - let vdaf = Prio3::new_count(2).context("failed to construct Prio3Count VDAF")?; - handle_collect_generic( - http_client, - task_state, - Query::new_time_interval(batch_interval), - vdaf, - &agg_param, - |_| None, - |result| AggregationResult::Number(NumberAsString((*result).into())), - ) - .await? - } + let task_handle = match (query, vdaf_instance) { + (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Count {}) => { + let vdaf = Prio3::new_count(2).context("failed to construct Prio3Count VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_time_interval(batch_interval), + vdaf, + &agg_param, + |_| None, + |result| AggregationResult::Number(NumberAsString((*result).into())), + ) + .await? + } - (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Sum { bits }) => { - let vdaf = Prio3::new_sum(2, bits).context("failed to construct Prio3Sum VDAF")?; - handle_collect_generic( - http_client, - task_state, - Query::new_time_interval(batch_interval), - vdaf, - &agg_param, - |_| None, - |result| AggregationResult::Number(NumberAsString(*result)), - ) - .await? - } + (ParsedQuery::TimeInterval(batch_interval), VdafInstance::Prio3Sum { max_measurement }) => { + let vdaf = Prio3::new_sum(2, u128::from(max_measurement)) + .context("failed to construct Prio3Sum VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_time_interval(batch_interval), + vdaf, + &agg_param, + |_| None, + |result| AggregationResult::Number(NumberAsString(*result)), + ) + .await? + } - ( - ParsedQuery::TimeInterval(batch_interval), - VdafInstance::Prio3SumVec { - bits, - length, - chunk_length, - dp_strategy: _, + ( + ParsedQuery::TimeInterval(batch_interval), + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + dp_strategy: _, + }, + ) => { + let vdaf = Prio3::new_sum_vec(2, bits, length, chunk_length) + .context("failed to construct Prio3SumVec VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_time_interval(batch_interval), + vdaf, + &agg_param, + |_| None, + |result| { + let converted = result.iter().cloned().map(NumberAsString).collect(); + AggregationResult::NumberVec(converted) }, - ) => { - let vdaf = Prio3::new_sum_vec(2, bits, length, chunk_length) - .context("failed to construct Prio3SumVec VDAF")?; - handle_collect_generic( - http_client, - task_state, - Query::new_time_interval(batch_interval), - vdaf, - &agg_param, - |_| None, - |result| { - let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::NumberVec(converted) - }, - ) - .await? - } + ) + .await? + } - ( - ParsedQuery::TimeInterval(batch_interval), - VdafInstance::Prio3SumVecField64MultiproofHmacSha256Aes128 { - proofs, - bits, - length, - chunk_length, - dp_strategy: _, + ( + ParsedQuery::TimeInterval(batch_interval), + VdafInstance::Prio3SumVecField64MultiproofHmacSha256Aes128 { + proofs, + bits, + length, + chunk_length, + dp_strategy: _, + }, + ) => { + let vdaf = new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128::>( + proofs, + bits, + length, + chunk_length, + ) + .context("failed to construct Prio3SumVecField64MultiproofHmacSha256Aes128 VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_time_interval(batch_interval), + vdaf, + &agg_param, + |_| None, + |result: &Vec| { + let converted = result + .iter() + .cloned() + .map(u128::from) + .map(NumberAsString) + .collect(); + AggregationResult::NumberVec(converted) }, - ) => { - let vdaf = new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128::< - ParallelSum<_, _>, - >(proofs, bits, length, chunk_length) - .context("failed to construct Prio3SumVecField64MultiproofHmacSha256Aes128 VDAF")?; + ) + .await? + } + + ( + ParsedQuery::TimeInterval(batch_interval), + VdafInstance::Prio3Histogram { + length, + chunk_length, + dp_strategy: _, + }, + ) => { + let vdaf = Prio3::new_histogram(2, length, chunk_length) + .context("failed to construct Prio3Histogram VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_time_interval(batch_interval), + vdaf, + &agg_param, + |_| None, + |result| { + let converted = result.iter().cloned().map(NumberAsString).collect(); + AggregationResult::NumberVec(converted) + }, + ) + .await? + } + + #[cfg(feature = "fpvec_bounded_l2")] + ( + ParsedQuery::TimeInterval(batch_interval), + janus_core::vdaf::VdafInstance::Prio3FixedPointBoundedL2VecSum { + bitsize, + dp_strategy: _, + length, + }, + ) => match bitsize { + Prio3FixedPointBoundedL2VecSumBitSize::BitSize16 => { + let vdaf: Prio3FixedPointBoundedL2VecSum> = + Prio3::new_fixedpoint_boundedl2_vec_sum(2, length) + .context("failed to construct Prio3FixedPoint16BitBoundedL2VecSum VDAF")?; handle_collect_generic( http_client, task_state, @@ -366,29 +427,17 @@ async fn handle_collection_start( vdaf, &agg_param, |_| None, - |result: &Vec| { - let converted = result - .iter() - .cloned() - .map(u128::from) - .map(NumberAsString) - .collect(); - AggregationResult::NumberVec(converted) + |result| { + let converted = result.iter().cloned().map(NumberAsString).collect(); + AggregationResult::FloatVec(converted) }, ) .await? } - - ( - ParsedQuery::TimeInterval(batch_interval), - VdafInstance::Prio3Histogram { - length, - chunk_length, - dp_strategy: _, - }, - ) => { - let vdaf = Prio3::new_histogram(2, length, chunk_length) - .context("failed to construct Prio3Histogram VDAF")?; + Prio3FixedPointBoundedL2VecSumBitSize::BitSize32 => { + let vdaf: Prio3FixedPointBoundedL2VecSum> = + Prio3::new_fixedpoint_boundedl2_vec_sum(2, length) + .context("failed to construct Prio3FixedPoint32BitBoundedL2VecSum VDAF")?; handle_collect_generic( http_client, task_state, @@ -398,126 +447,40 @@ async fn handle_collection_start( |_| None, |result| { let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::NumberVec(converted) + AggregationResult::FloatVec(converted) }, ) .await? } + }, - #[cfg(feature = "fpvec_bounded_l2")] - ( - ParsedQuery::TimeInterval(batch_interval), - janus_core::vdaf::VdafInstance::Prio3FixedPointBoundedL2VecSum { - bitsize, - dp_strategy: _, - length, - }, - ) => match bitsize { - Prio3FixedPointBoundedL2VecSumBitSize::BitSize16 => { - let vdaf: Prio3FixedPointBoundedL2VecSum> = - Prio3::new_fixedpoint_boundedl2_vec_sum(2, length).context( - "failed to construct Prio3FixedPoint16BitBoundedL2VecSum VDAF", - )?; - handle_collect_generic( - http_client, - task_state, - Query::new_time_interval(batch_interval), - vdaf, - &agg_param, - |_| None, - |result| { - let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::FloatVec(converted) - }, - ) - .await? - } - Prio3FixedPointBoundedL2VecSumBitSize::BitSize32 => { - let vdaf: Prio3FixedPointBoundedL2VecSum> = - Prio3::new_fixedpoint_boundedl2_vec_sum(2, length).context( - "failed to construct Prio3FixedPoint32BitBoundedL2VecSum VDAF", - )?; - handle_collect_generic( - http_client, - task_state, - Query::new_time_interval(batch_interval), - vdaf, - &agg_param, - |_| None, - |result| { - let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::FloatVec(converted) - }, - ) - .await? - } - }, - - (ParsedQuery::LeaderSelected, VdafInstance::Prio3Count {}) => { - let vdaf = Prio3::new_count(2).context("failed to construct Prio3Count VDAF")?; - handle_collect_generic( - http_client, - task_state, - Query::new_leader_selected(), - vdaf, - &agg_param, - |selector| Some(*selector.batch_id()), - |result| AggregationResult::Number(NumberAsString((*result).into())), - ) - .await? - } + (ParsedQuery::LeaderSelected, VdafInstance::Prio3Count {}) => { + let vdaf = Prio3::new_count(2).context("failed to construct Prio3Count VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_leader_selected(), + vdaf, + &agg_param, + |selector| Some(*selector.batch_id()), + |result| AggregationResult::Number(NumberAsString((*result).into())), + ) + .await? + } - #[cfg(feature = "fpvec_bounded_l2")] - ( - ParsedQuery::LeaderSelected, - janus_core::vdaf::VdafInstance::Prio3FixedPointBoundedL2VecSum { - bitsize, - dp_strategy: _, - length, - }, - ) => match bitsize { - Prio3FixedPointBoundedL2VecSumBitSize::BitSize16 => { - let vdaf: Prio3FixedPointBoundedL2VecSum> = - Prio3::new_fixedpoint_boundedl2_vec_sum(2, length).context( - "failed to construct Prio3FixedPoint16BitBoundedL2VecSum VDAF", - )?; - handle_collect_generic( - http_client, - task_state, - Query::new_leader_selected(), - vdaf, - &agg_param, - |selector| Some(*selector.batch_id()), - |result| { - let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::FloatVec(converted) - }, - ) - .await? - } - Prio3FixedPointBoundedL2VecSumBitSize::BitSize32 => { - let vdaf: Prio3FixedPointBoundedL2VecSum> = - Prio3::new_fixedpoint_boundedl2_vec_sum(2, length).context( - "failed to construct Prio3FixedPoint32BitBoundedL2VecSum VDAF", - )?; - handle_collect_generic( - http_client, - task_state, - Query::new_leader_selected(), - vdaf, - &agg_param, - |selector| Some(*selector.batch_id()), - |result| { - let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::FloatVec(converted) - }, - ) - .await? - } + #[cfg(feature = "fpvec_bounded_l2")] + ( + ParsedQuery::LeaderSelected, + janus_core::vdaf::VdafInstance::Prio3FixedPointBoundedL2VecSum { + bitsize, + dp_strategy: _, + length, }, - - (ParsedQuery::LeaderSelected, VdafInstance::Prio3Sum { bits }) => { - let vdaf = Prio3::new_sum(2, bits).context("failed to construct Prio3Sum VDAF")?; + ) => match bitsize { + Prio3FixedPointBoundedL2VecSumBitSize::BitSize16 => { + let vdaf: Prio3FixedPointBoundedL2VecSum> = + Prio3::new_fixedpoint_boundedl2_vec_sum(2, length) + .context("failed to construct Prio3FixedPoint16BitBoundedL2VecSum VDAF")?; handle_collect_generic( http_client, task_state, @@ -525,22 +488,17 @@ async fn handle_collection_start( vdaf, &agg_param, |selector| Some(*selector.batch_id()), - |result| AggregationResult::Number(NumberAsString(*result)), + |result| { + let converted = result.iter().cloned().map(NumberAsString).collect(); + AggregationResult::FloatVec(converted) + }, ) .await? } - - ( - ParsedQuery::LeaderSelected, - VdafInstance::Prio3SumVec { - bits, - length, - chunk_length, - dp_strategy: _, - }, - ) => { - let vdaf = Prio3::new_sum_vec(2, bits, length, chunk_length) - .context("failed to construct Prio3SumVec VDAF")?; + Prio3FixedPointBoundedL2VecSumBitSize::BitSize32 => { + let vdaf: Prio3FixedPointBoundedL2VecSum> = + Prio3::new_fixedpoint_boundedl2_vec_sum(2, length) + .context("failed to construct Prio3FixedPoint32BitBoundedL2VecSum VDAF")?; handle_collect_generic( http_client, task_state, @@ -550,75 +508,120 @@ async fn handle_collection_start( |selector| Some(*selector.batch_id()), |result| { let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::NumberVec(converted) + AggregationResult::FloatVec(converted) }, ) .await? } + }, + + (ParsedQuery::LeaderSelected, VdafInstance::Prio3Sum { max_measurement }) => { + let vdaf = Prio3::new_sum(2, u128::from(max_measurement)) + .context("failed to construct Prio3Sum VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_leader_selected(), + vdaf, + &agg_param, + |selector| Some(*selector.batch_id()), + |result| AggregationResult::Number(NumberAsString(*result)), + ) + .await? + } - ( - ParsedQuery::LeaderSelected, - VdafInstance::Prio3SumVecField64MultiproofHmacSha256Aes128 { - proofs, - bits, - length, - chunk_length, - dp_strategy: _, + ( + ParsedQuery::LeaderSelected, + VdafInstance::Prio3SumVec { + bits, + length, + chunk_length, + dp_strategy: _, + }, + ) => { + let vdaf = Prio3::new_sum_vec(2, bits, length, chunk_length) + .context("failed to construct Prio3SumVec VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_leader_selected(), + vdaf, + &agg_param, + |selector| Some(*selector.batch_id()), + |result| { + let converted = result.iter().cloned().map(NumberAsString).collect(); + AggregationResult::NumberVec(converted) }, - ) => { - let vdaf = new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128::< - ParallelSum<_, _>, - >(proofs, bits, length, chunk_length) - .context("failed to construct Prio3SumVecField64MultiproofHmacSha256Aes128 VDAF")?; - handle_collect_generic( - http_client, - task_state, - Query::new_leader_selected(), - vdaf, - &agg_param, - |selector| Some(*selector.batch_id()), - |result: &Vec| { - let converted = result - .iter() - .cloned() - .map(u128::from) - .map(NumberAsString) - .collect(); - AggregationResult::NumberVec(converted) - }, - ) - .await? - } + ) + .await? + } - ( - ParsedQuery::LeaderSelected, - VdafInstance::Prio3Histogram { - length, - chunk_length, - dp_strategy: _, + ( + ParsedQuery::LeaderSelected, + VdafInstance::Prio3SumVecField64MultiproofHmacSha256Aes128 { + proofs, + bits, + length, + chunk_length, + dp_strategy: _, + }, + ) => { + let vdaf = new_prio3_sum_vec_field64_multiproof_hmacsha256_aes128::>( + proofs, + bits, + length, + chunk_length, + ) + .context("failed to construct Prio3SumVecField64MultiproofHmacSha256Aes128 VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_leader_selected(), + vdaf, + &agg_param, + |selector| Some(*selector.batch_id()), + |result: &Vec| { + let converted = result + .iter() + .cloned() + .map(u128::from) + .map(NumberAsString) + .collect(); + AggregationResult::NumberVec(converted) }, - ) => { - let vdaf = Prio3::new_histogram(2, length, chunk_length) - .context("failed to construct Prio3Histogram VDAF")?; - handle_collect_generic( - http_client, - task_state, - Query::new_leader_selected(), - vdaf, - &agg_param, - |selector| Some(*selector.batch_id()), - |result| { - let converted = result.iter().cloned().map(NumberAsString).collect(); - AggregationResult::NumberVec(converted) - }, - ) - .await? - } + ) + .await? + } - (_, vdaf_instance) => { - panic!("Unsupported VDAF: {vdaf_instance:?}") - } - }; + ( + ParsedQuery::LeaderSelected, + VdafInstance::Prio3Histogram { + length, + chunk_length, + dp_strategy: _, + }, + ) => { + let vdaf = Prio3::new_histogram(2, length, chunk_length) + .context("failed to construct Prio3Histogram VDAF")?; + handle_collect_generic( + http_client, + task_state, + Query::new_leader_selected(), + vdaf, + &agg_param, + |selector| Some(*selector.batch_id()), + |result| { + let converted = result.iter().cloned().map(NumberAsString).collect(); + AggregationResult::NumberVec(converted) + }, + ) + .await? + } + + (_, vdaf_instance) => { + panic!("Unsupported VDAF: {vdaf_instance:?}") + } + }; let mut collection_jobs_guard = collection_jobs.lock().await; Ok(loop { diff --git a/interop_binaries/src/lib.rs b/interop_binaries/src/lib.rs index 87388afe0..04ac975a9 100644 --- a/interop_binaries/src/lib.rs +++ b/interop_binaries/src/lib.rs @@ -111,7 +111,7 @@ where pub enum VdafObject { Prio3Count, Prio3Sum { - bits: NumberAsString, + max_measurement: NumberAsString, }, Prio3SumVec { bits: NumberAsString, @@ -142,8 +142,8 @@ impl From for VdafObject { match vdaf { VdafInstance::Prio3Count => VdafObject::Prio3Count, - VdafInstance::Prio3Sum { bits } => VdafObject::Prio3Sum { - bits: NumberAsString(bits), + VdafInstance::Prio3Sum { max_measurement } => VdafObject::Prio3Sum { + max_measurement: NumberAsString(max_measurement), }, VdafInstance::Prio3SumVec { @@ -200,7 +200,9 @@ impl From for VdafInstance { match vdaf { VdafObject::Prio3Count => VdafInstance::Prio3Count, - VdafObject::Prio3Sum { bits } => VdafInstance::Prio3Sum { bits: bits.0 }, + VdafObject::Prio3Sum { max_measurement } => VdafInstance::Prio3Sum { + max_measurement: max_measurement.0, + }, VdafObject::Prio3SumVec { bits, diff --git a/interop_binaries/tests/end_to_end.rs b/interop_binaries/tests/end_to_end.rs index 7a880c18a..9276acf8b 100644 --- a/interop_binaries/tests/end_to_end.rs +++ b/interop_binaries/tests/end_to_end.rs @@ -612,7 +612,7 @@ async fn e2e_prio3_sum() { let result = run( "e2e_prio3_sum", QueryKind::TimeInterval, - json!({"type": "Prio3Sum", "bits": "64"}), + json!({"type": "Prio3Sum", "max_measurement": "255"}), &[ json!("0"), json!("10"), diff --git a/tools/src/bin/collect.rs b/tools/src/bin/collect.rs index 11e45d203..9805ec2a5 100644 --- a/tools/src/bin/collect.rs +++ b/tools/src/bin/collect.rs @@ -314,9 +314,12 @@ struct Options { /// used with --vdaf=histogram #[clap(long, help_heading = "VDAF Algorithm and Parameters")] length: Option, - /// Bit length of measurements, for use with --vdaf=sum and --vdaf=sumvec + /// Bit length of measurements, for use with --vdaf=sumvec #[clap(long, help_heading = "VDAF Algorithm and Parameters")] bits: Option, + /// Maximum measurement value, for use with --vdaf=sum + #[clap(long, help_heading = "VDAF Algorithm and Parameters")] + max_measurement: Option, #[clap(flatten)] query: QueryOptions, @@ -450,18 +453,24 @@ macro_rules! options_query_dispatch { macro_rules! options_vdaf_dispatch { ($options:expr, ($vdaf:ident) => $body:tt) => { - match ($options.vdaf, $options.length, $options.bits) { - (VdafType::Count, None, None) => { + match ( + $options.vdaf, + $options.length, + $options.bits, + $options.max_measurement, + ) { + (VdafType::Count, None, None, None) => { let $vdaf = Prio3::new_count(2).map_err(|err| Error::Anyhow(err.into()))?; let body = $body; body } - (VdafType::Sum, None, Some(bits)) => { - let $vdaf = Prio3::new_sum(2, bits).map_err(|err| Error::Anyhow(err.into()))?; + (VdafType::Sum, None, None, Some(max_measurement)) => { + let $vdaf = Prio3::new_sum(2, u128::from(max_measurement)) + .map_err(|err| Error::Anyhow(err.into()))?; let body = $body; body } - (VdafType::SumVec, Some(length), Some(bits)) => { + (VdafType::SumVec, Some(length), Some(bits), None) => { // We can take advantage of the fact that Prio3SumVec unsharding does not use the // chunk_length parameter and avoid asking the user for it. let $vdaf = Prio3::new_sum_vec(2, bits, length, 1) @@ -469,7 +478,7 @@ macro_rules! options_vdaf_dispatch { let body = $body; body } - (VdafType::Histogram, Some(length), None) => { + (VdafType::Histogram, Some(length), None, None) => { // We can take advantage of the fact that Prio3Histogram unsharding does not use the // chunk_length parameter and avoid asking the user for it. let $vdaf = @@ -478,7 +487,7 @@ macro_rules! options_vdaf_dispatch { body } #[cfg(feature = "fpvec_bounded_l2")] - (VdafType::FixedPoint16BitBoundedL2VecSum, Some(length), None) => { + (VdafType::FixedPoint16BitBoundedL2VecSum, Some(length), None, None) => { let $vdaf: Prio3FixedPointBoundedL2VecSum> = Prio3::new_fixedpoint_boundedl2_vec_sum(2, length) .map_err(|err| Error::Anyhow(err.into()))?; @@ -486,7 +495,7 @@ macro_rules! options_vdaf_dispatch { body } #[cfg(feature = "fpvec_bounded_l2")] - (VdafType::FixedPoint32BitBoundedL2VecSum, Some(length), None) => { + (VdafType::FixedPoint32BitBoundedL2VecSum, Some(length), None, None) => { let $vdaf: Prio3FixedPointBoundedL2VecSum> = Prio3::new_fixedpoint_boundedl2_vec_sum(2, length) .map_err(|err| Error::Anyhow(err.into()))?; @@ -758,6 +767,7 @@ mod tests { vdaf: VdafType::Count, length: None, bits: None, + max_measurement: None, query: QueryOptions { batch_interval_start: Some(1_000_000), batch_interval_duration: Some(1_000), @@ -991,6 +1001,7 @@ mod tests { vdaf: VdafType::Count, length: None, bits: None, + max_measurement: None, query: QueryOptions { batch_interval_start: None, batch_interval_duration: None, @@ -1332,6 +1343,7 @@ mod tests { vdaf: VdafType::Count, length: None, bits: None, + max_measurement: None, query: QueryOptions { batch_interval_start: Some(1_000_000), batch_interval_duration: Some(1_000), @@ -1415,6 +1427,7 @@ mod tests { vdaf: VdafType::Count, length: None, bits: None, + max_measurement: None, query: QueryOptions { batch_interval_start: Some(1_000_000), batch_interval_duration: Some(1_000), diff --git a/tools/tests/cmd/collect.trycmd b/tools/tests/cmd/collect.trycmd index 5c6e7a7f4..69594f9be 100644 --- a/tools/tests/cmd/collect.trycmd +++ b/tools/tests/cmd/collect.trycmd @@ -74,7 +74,10 @@ VDAF Algorithm and Parameters: Number of vector elements, when used with --vdaf=sumvec or number of histogram buckets, when used with --vdaf=histogram --bits - Bit length of measurements, for use with --vdaf=sum and --vdaf=sumvec + Bit length of measurements, for use with --vdaf=sumvec + + --max-measurement + Maximum measurement value, for use with --vdaf=sum Collect Request Parameters (Time Interval): --batch-interval-start