diff --git a/Cargo.lock b/Cargo.lock
index 55c704343..b50604a8e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -965,6 +965,8 @@ dependencies = [
  "axum",
  "axum-extra",
  "bytes",
+ "capnp",
+ "capnpc",
  "chrono",
  "constcat",
  "daphne",
@@ -997,6 +999,7 @@ dependencies = [
  "tracing-core",
  "tracing-subscriber",
  "url",
+ "wasm-bindgen",
  "webpki",
  "worker",
 ]
diff --git a/Cargo.toml b/Cargo.toml
index 51cefcea0..a0dd23562 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -90,8 +90,9 @@ tracing = "0.1.40"
 tracing-core = "0.1.32"
 tracing-subscriber = "0.3.18"
 url = { version = "2.5.4", features = ["serde"] }
+wasm-bindgen = "0.2.99"
 webpki = "0.22.4"
-worker = { version = "0.5", features = ["http"] }
+worker = "0.5"
 x509-parser = "0.15.1"
 
 [workspace.dependencies.sentry]
diff --git a/Makefile b/Makefile
index 0e5780680..f856e7e93 100644
--- a/Makefile
+++ b/Makefile
@@ -21,6 +21,16 @@ helper:
 		-c ./crates/daphne-server/examples/configuration-helper.toml
 h: helper
 
+compute-offload:
+	RUST_LOG=hyper=off,debug cargo run \
+		--profile release-symbols \
+		--features test-utils \
+		--example service \
+		-- \
+		-c ./crates/daphne-server/examples/configuration-cpu-offload.toml
+co: compute-offload
+
+
 helper-worker:
 	cd ./crates/daphne-worker-test/ && \
 		wrangler dev -c wrangler.aggregator.toml --port 8788 -e helper
diff --git a/crates/dapf/src/acceptance/mod.rs b/crates/dapf/src/acceptance/mod.rs
index fb108b6a8..68281a11a 100644
--- a/crates/dapf/src/acceptance/mod.rs
+++ b/crates/dapf/src/acceptance/mod.rs
@@ -27,6 +27,7 @@ use daphne::{
         BatchId, BatchSelector, PartialBatchSelector, TaskId,
     },
     metrics::DaphneMetrics,
+    protocol::ReadyAggregationJobResp,
     testing::report_generator::ReportGenerator,
     vdaf::VdafConfig,
     DapAggregateShare, DapAggregateSpan, DapAggregationParam, DapBatchMode, DapMeasurement,
@@ -511,7 +512,7 @@ impl Test {
         let _guard = load_control.wait().await;
         info!("Starting AggregationJobInitReq");
         let start = Instant::now();
-        let agg_job_resp = self
+        let mut agg_job_resp = self
             .http_client
             .submit_aggregation_job_init_req(
                 self.helper_url.join(&format!(
@@ -528,14 +529,42 @@ impl Test {
             )
             .await?;
         let duration = start.elapsed();
-        info!("Finished AggregationJobInitReq in {duration:#?}");
+        info!("Finished submitting AggregationJobInitReq in {duration:#?}");
+        let mut poll_count = 1;
+        let ready = loop {
+            agg_job_resp = match agg_job_resp {
+                messages::AggregationJobResp::Ready { prep_resps } => {
+                    if poll_count != 1 {
+                        info!(
+                            "Finished polling for AggregationJobResp after {:#?}",
+                            start.elapsed()
+                        );
+                    }
+                    break ReadyAggregationJobResp { prep_resps };
+                }
+                messages::AggregationJobResp::Processing => {
+                    if poll_count == 1 {
+                        info!("Polling for AggregationJobResp");
+                    }
+                    tokio::time::sleep(Duration::from_millis(poll_count * 200)).await;
+                    poll_count += 1;
+                    self.http_client
+                        .poll_aggregation_job_init(
+                            self.helper_url
+                                .join(&format!("tasks/{task_id}/aggregation_jobs/{agg_job_id}"))?,
+                            task_config.version,
+                            functions::helper::Options {
+                                taskprov_advertisement: taskprov_advertisement.as_ref(),
+                                bearer_token: self.bearer_token.as_ref(),
+                            },
+                        )
+                        .await?
+                }
+            };
+        };
 
-        let agg_share_span = task_config.consume_agg_job_resp(
-            task_id,
-            agg_job_state,
-            agg_job_resp.unwrap_ready(), // TODO: implement polling
-            self.metrics(),
-        )?;
+        let agg_share_span =
+            task_config.consume_agg_job_resp(task_id, agg_job_state, ready, self.metrics())?;
 
         let aggregated_report_count = agg_share_span
             .iter()
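Reviewer note: the new dapf loop polls the helper with a linearly growing backoff (200 ms, then 400 ms, then 600 ms, ...). A minimal standalone sketch of the same pattern, with a hypothetical `fetch_status` closure standing in for the real `poll_aggregation_job_init` call:

```rust
use std::time::Duration;

// `Status` stands in for messages::AggregationJobResp and `fetch_status` for
// poll_aggregation_job_init; both names are assumptions made for this sketch.
enum Status<T> {
    Ready(T),
    Processing,
}

async fn poll_until_ready<T, F, Fut>(mut fetch_status: F) -> T
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Status<T>>,
{
    let mut poll_count: u64 = 1;
    loop {
        match fetch_status().await {
            Status::Ready(t) => break t,
            Status::Processing => {
                // Linear backoff, mirroring the loop above: 200ms, 400ms, 600ms, ...
                tokio::time::sleep(Duration::from_millis(poll_count * 200)).await;
                poll_count += 1;
            }
        }
    }
}
```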
diff --git a/crates/daphne-server/docker/example-service.Dockerfile b/crates/daphne-server/docker/example-service.Dockerfile
index d18e1d749..0aae64c15 100644
--- a/crates/daphne-server/docker/example-service.Dockerfile
+++ b/crates/daphne-server/docker/example-service.Dockerfile
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-FROM rust:1.80-bookworm AS builder
+FROM rust:1.84.1-bookworm AS builder
 
 RUN apt update && \
     apt install -y \
diff --git a/crates/daphne-server/src/roles/aggregator.rs b/crates/daphne-server/src/roles/aggregator.rs
index 7018bb82e..8bfec8558 100644
--- a/crates/daphne-server/src/roles/aggregator.rs
+++ b/crates/daphne-server/src/roles/aggregator.rs
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
+// Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
 use std::{future::ready, num::NonZeroUsize, ops::Range, time::SystemTime};
@@ -79,6 +79,7 @@ impl DapAggregator for crate::App {
     #[tracing::instrument(skip(self))]
     async fn get_agg_share(
         &self,
+        _version: DapVersion,
         task_id: &TaskId,
         batch_sel: &BatchSelector,
     ) -> Result<DapAggregateShare, DapError> {
@@ -115,6 +116,7 @@ impl DapAggregator for crate::App {
     #[tracing::instrument(skip(self))]
     async fn mark_collected(
         &self,
+        _version: DapVersion,
         task_id: &TaskId,
         batch_sel: &BatchSelector,
     ) -> Result<(), DapError> {
@@ -255,6 +257,7 @@ impl DapAggregator for crate::App {
 
     async fn is_batch_overlapping(
         &self,
+        _version: DapVersion,
        task_id: &TaskId,
         batch_sel: &BatchSelector,
     ) -> Result<bool, DapError> {
@@ -288,7 +291,12 @@ impl DapAggregator for crate::App {
         )
     }
 
-    async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result<bool, DapError> {
+    async fn batch_exists(
+        &self,
+        _version: DapVersion,
+        task_id: &TaskId,
+        batch_id: &BatchId,
+    ) -> Result<bool, DapError> {
         let task_config = self
             .get_task_config_for(task_id)
             .await?
diff --git a/crates/daphne-server/src/roles/helper.rs b/crates/daphne-server/src/roles/helper.rs
index 12fe1f8ae..822ca510d 100644
--- a/crates/daphne-server/src/roles/helper.rs
+++ b/crates/daphne-server/src/roles/helper.rs
@@ -3,7 +3,8 @@
 
 use axum::async_trait;
 use daphne::{
-    messages::{AggregationJobId, TaskId},
+    fatal_error,
+    messages::{AggregationJobId, AggregationJobResp, TaskId},
     roles::{helper::AggregationJobRequestHash, DapHelper},
     DapError, DapVersion,
 };
@@ -20,4 +21,13 @@ impl DapHelper for crate::App {
         // the server implementation can't check for this
         Ok(())
     }
+
+    async fn poll_aggregated(
+        &self,
+        _version: DapVersion,
+        _task_id: &TaskId,
+        _agg_job_id: &AggregationJobId,
+    ) -> Result<AggregationJobResp, DapError> {
+        Err(fatal_error!(err = "polling not implemented"))
+    }
 }
diff --git a/crates/daphne-server/src/roles/leader.rs b/crates/daphne-server/src/roles/leader.rs
index dde6f629c..800dbfa89 100644
--- a/crates/daphne-server/src/roles/leader.rs
+++ b/crates/daphne-server/src/roles/leader.rs
@@ -121,6 +121,10 @@ impl DapLeader for crate::App {
     {
         self.send_http(meta, Method::PUT, url, payload).await
     }
+
+    async fn send_http_get(&self, meta: DapRequestMeta, url: Url) -> Result<DapResponse, DapError> {
+        self.send_http(meta, Method::GET, url, ()).await
+    }
 }
 
 impl crate::App {
diff --git a/crates/daphne-server/src/router/extractor.rs b/crates/daphne-server/src/router/extractor.rs
index 51d089691..9ea492529 100644
--- a/crates/daphne-server/src/router/extractor.rs
+++ b/crates/daphne-server/src/router/extractor.rs
@@ -72,12 +72,12 @@ impl DecodeFromDapHttpBody for HashedAggregationJobReq {
     }
 }
 
-/// Using `()` ignores the body of a request.
 impl DecodeFromDapHttpBody for CollectionPollReq {
     fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result<Self, Response> {
         Ok(Self)
     }
 }
 
+/// Using `()` ignores the body of a request.
 impl DecodeFromDapHttpBody for () {
     fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result<Self, Response> {
diff --git a/crates/daphne-service-utils/build.rs b/crates/daphne-service-utils/build.rs
index 317e18715..0f3c692f7 100644
--- a/crates/daphne-service-utils/build.rs
+++ b/crates/daphne-service-utils/build.rs
@@ -11,7 +11,10 @@ fn main() {
     #[cfg(feature = "durable_requests")]
     compiler
         .file("./src/durable_requests/durable_request.capnp")
-        .file("./src/durable_requests/bindings/aggregation_job_store.capnp");
+        .file("./src/durable_requests/bindings/aggregation_job_store.capnp")
+        .file("./src/durable_requests/bindings/aggregate_store_v2.capnp")
+        .file("./src/durable_requests/bindings/agg_job_response_store.capnp")
+        .file("./src/durable_requests/bindings/replay_checker.capnp");
 
     #[cfg(feature = "compute-offload")]
     compiler.file("./src/compute_offload/compute_offload.capnp");
diff --git a/crates/daphne-service-utils/src/capnproto/base.capnp b/crates/daphne-service-utils/src/capnproto/base.capnp
index ab8668f4b..0329c9758 100644
--- a/crates/daphne-service-utils/src/capnproto/base.capnp
+++ b/crates/daphne-service-utils/src/capnproto/base.capnp
@@ -22,13 +22,29 @@ struct U8L16 @0x9e3f65b13f71cfcb {
     snd @1 :UInt64;
 }
 
-struct PartialBatchSelector {
+struct PartialBatchSelector @0xae86084e56c22fc0 {
     union {
         timeInterval @0 :Void;
         leaderSelectedByBatchId @1 :BatchId;
     }
 }
 
+enum ReportError @0xa76428617779e659 {
+    reserved @0;
+    batchCollected @1;
+    reportReplayed @2;
+    reportDropped @3;
+    hpkeUnknownConfigId @4;
+    hpkeDecryptError @5;
+    vdafPrepError @6;
+    batchSaturated @7;
+    taskExpired @8;
+    invalidMessage @9;
+    reportTooEarly @10;
+    taskNotStarted @11;
+}
+
 using ReportId = U8L16;
 using BatchId = U8L32;
 using TaskId = U8L32;
diff --git a/crates/daphne-service-utils/src/capnproto/mod.rs b/crates/daphne-service-utils/src/capnproto/mod.rs
index ed7e7c989..ea2954d6f 100644
--- a/crates/daphne-service-utils/src/capnproto/mod.rs
+++ b/crates/daphne-service-utils/src/capnproto/mod.rs
@@ -4,6 +4,7 @@
 use crate::base_capnp::{self, partial_batch_selector, u8_l16, u8_l32};
 use capnp::struct_list;
 use capnp::traits::{FromPointerBuilder, FromPointerReader};
+use daphne::messages;
 use daphne::{
     messages::{AggregationJobId, BatchId, PartialBatchSelector, ReportId, TaskId},
     DapVersion,
@@ -204,6 +205,44 @@ impl CapnprotoPayloadDecode for PartialBatchSelector {
     }
 }
 
+impl From<messages::ReportError> for base_capnp::ReportError {
+    fn from(failure: messages::ReportError) -> Self {
+        match failure {
+            messages::ReportError::Reserved => Self::Reserved,
+            messages::ReportError::BatchCollected => Self::BatchCollected,
+            messages::ReportError::ReportReplayed => Self::ReportReplayed,
+            messages::ReportError::ReportDropped => Self::ReportDropped,
+            messages::ReportError::HpkeUnknownConfigId => Self::HpkeUnknownConfigId,
+            messages::ReportError::HpkeDecryptError => Self::HpkeDecryptError,
+            messages::ReportError::VdafPrepError => Self::VdafPrepError,
+            messages::ReportError::BatchSaturated => Self::BatchSaturated,
+            messages::ReportError::TaskExpired => Self::TaskExpired,
+            messages::ReportError::InvalidMessage => Self::InvalidMessage,
+            messages::ReportError::ReportTooEarly => Self::ReportTooEarly,
+            messages::ReportError::TaskNotStarted => Self::TaskNotStarted,
+        }
+    }
+}
+
+impl From<base_capnp::ReportError> for messages::ReportError {
+    fn from(val: base_capnp::ReportError) -> Self {
+        match val {
+            base_capnp::ReportError::Reserved => Self::Reserved,
+            base_capnp::ReportError::BatchCollected => Self::BatchCollected,
+            base_capnp::ReportError::ReportReplayed => Self::ReportReplayed,
+            base_capnp::ReportError::ReportDropped => Self::ReportDropped,
+            base_capnp::ReportError::HpkeUnknownConfigId => Self::HpkeUnknownConfigId,
+            base_capnp::ReportError::HpkeDecryptError => Self::HpkeDecryptError,
+            base_capnp::ReportError::VdafPrepError => Self::VdafPrepError,
+            base_capnp::ReportError::BatchSaturated => Self::BatchSaturated,
+            base_capnp::ReportError::TaskExpired => Self::TaskExpired,
+            base_capnp::ReportError::InvalidMessage => Self::InvalidMessage,
+            base_capnp::ReportError::ReportTooEarly => Self::ReportTooEarly,
+            base_capnp::ReportError::TaskNotStarted => Self::TaskNotStarted,
+        }
+    }
+}
+
 pub fn encode_list<I, O>(list: I, mut builder: struct_list::Builder<'_, O>)
 where
     I: IntoIterator,
diff --git a/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp b/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp
index 581baf0a7..8268081ca 100644
--- a/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp
+++ b/crates/daphne-service-utils/src/compute_offload/compute_offload.capnp
@@ -3,8 +3,6 @@
 
 @0xd932f3d934afce3b;
 
-# Utilities
-
 using Base = import "../capnproto/base.capnp";
 
 using VdafConfig = Text; # json encoded
@@ -94,27 +92,11 @@ struct PrepareInit @0x8192568cb3d03f59 {
 
 
-struct InitializedReports {
-    struct InitializedReport {
+struct InitializedReports @0xf36341397ae4a146 {
+    struct InitializedReport @0xfa833aa6b5d03d6d {
         using VdafPrepShare = Data;
         using VdafPrepState = Data;
 
-        enum ReportError {
-            reserved @0;
-            batchCollected @1;
-            reportReplayed @2;
-            reportDropped @3;
-            hpkeUnknownConfigId @4;
-            hpkeDecryptError @5;
-            vdafPrepError @6;
-            batchSaturated @7;
-            taskExpired @8;
-            invalidMessage @9;
-            reportTooEarly @10;
-            taskNotStarted @11;
-        }
-
         union {
             ready :group {
                 metadata @0 :ReportMetadata;
@@ -125,7 +107,7 @@ struct InitializedReports {
             }
             rejected :group {
                 metadata @5 :ReportMetadata;
-                failure @6 :ReportError;
+                failure @6 :Base.ReportError;
             }
         }
     }
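Reviewer note: hoisting `ReportError` into base.capnp gives the durable objects and the compute-offload schema a single wire enum, with the pair of `From` impls above as the only mapping. A hypothetical round-trip test pinning that mapping down (not part of this diff; `base_capnp` is the generated module used above):

```rust
#[test]
fn report_error_capnp_round_trip() {
    use daphne::messages::ReportError;
    // Spot-check a few variants; the match arms above are total, so any
    // variant added to one side without the other fails to compile.
    for e in [
        ReportError::Reserved,
        ReportError::BatchCollected,
        ReportError::TaskExpired,
        ReportError::TaskNotStarted,
    ] {
        let wire: base_capnp::ReportError = e.into();
        assert_eq!(ReportError::from(wire), e);
    }
}
```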
diff --git a/crates/daphne-service-utils/src/compute_offload/mod.rs b/crates/daphne-service-utils/src/compute_offload/mod.rs
index 142c1d963..4f5727f8f 100644
--- a/crates/daphne-service-utils/src/compute_offload/mod.rs
+++ b/crates/daphne-service-utils/src/compute_offload/mod.rs
@@ -479,44 +479,6 @@ impl CapnprotoPayloadDecode for InitializedReports {
     }
 }
 
-impl From<messages::ReportError> for initialized_report::ReportError {
-    fn from(failure: messages::ReportError) -> Self {
-        match failure {
-            messages::ReportError::Reserved => Self::Reserved,
-            messages::ReportError::BatchCollected => Self::BatchCollected,
-            messages::ReportError::ReportReplayed => Self::ReportReplayed,
-            messages::ReportError::ReportDropped => Self::ReportDropped,
-            messages::ReportError::HpkeUnknownConfigId => Self::HpkeUnknownConfigId,
-            messages::ReportError::HpkeDecryptError => Self::HpkeDecryptError,
-            messages::ReportError::VdafPrepError => Self::VdafPrepError,
-            messages::ReportError::BatchSaturated => Self::BatchSaturated,
-            messages::ReportError::TaskExpired => Self::TaskExpired,
-            messages::ReportError::InvalidMessage => Self::InvalidMessage,
-            messages::ReportError::ReportTooEarly => Self::ReportTooEarly,
-            messages::ReportError::TaskNotStarted => Self::TaskNotStarted,
-        }
-    }
-}
-
-impl From<initialized_report::ReportError> for messages::ReportError {
-    fn from(val: initialized_report::ReportError) -> Self {
-        match val {
-            initialized_report::ReportError::Reserved => Self::Reserved,
-            initialized_report::ReportError::BatchCollected => Self::BatchCollected,
-            initialized_report::ReportError::ReportReplayed => Self::ReportReplayed,
-            initialized_report::ReportError::ReportDropped => Self::ReportDropped,
-            initialized_report::ReportError::HpkeUnknownConfigId => Self::HpkeUnknownConfigId,
-            initialized_report::ReportError::HpkeDecryptError => Self::HpkeDecryptError,
-            initialized_report::ReportError::VdafPrepError => Self::VdafPrepError,
-            initialized_report::ReportError::BatchSaturated => Self::BatchSaturated,
-            initialized_report::ReportError::TaskExpired => Self::TaskExpired,
-            initialized_report::ReportError::InvalidMessage => Self::InvalidMessage,
-            initialized_report::ReportError::ReportTooEarly => Self::ReportTooEarly,
-            initialized_report::ReportError::TaskNotStarted => Self::TaskNotStarted,
-        }
-    }
-}
-
 fn to_capnp<E: std::fmt::Display>(e: E) -> capnp::Error {
     capnp::Error {
         kind: capnp::ErrorKind::Failed,
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.capnp b/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.capnp
new file mode 100644
index 000000000..8720351d1
--- /dev/null
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.capnp
@@ -0,0 +1,23 @@
+# Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+@0xd30da336463f3205;
+
+using Base = import "../../capnproto/base.capnp";
+
+struct AggregationJobResponse @0xebda3ce03fce7e72 {
+    struct PrepareRespVar @0xc41a0ca7156794f0 {
+        union {
+            continue @0 :Data;
+            reject @1 :Base.ReportError;
+        }
+    }
+
+    struct PrepareResp @0xc8b6a95ad17a2152 {
+        reportId @0 :Base.ReportId;
+        var @1 :PrepareRespVar;
+    }
+
+    prepResps @0 :List(PrepareResp);
+}
+
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.rs
new file mode 100644
index 000000000..0fe8d20ce
--- /dev/null
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/agg_job_response_store.rs
@@ -0,0 +1,122 @@
+// Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+
+use daphne::{
+    messages::{AggregationJobId, PrepareResp, PrepareRespVar, TaskId},
+    protocol::ReadyAggregationJobResp,
+    DapVersion,
+};
+
+use crate::{
+    agg_job_response_store_capnp::aggregation_job_response,
+    capnproto::{
+        decode_list, encode_list, usize_to_capnp_len, CapnprotoPayloadDecode,
+        CapnprotoPayloadEncode,
+    },
+    durable_requests::ObjectIdFrom,
+};
+
+super::define_do_binding! {
+    const BINDING = "AGGREGATE_JOB_RESULT_STORE";
+    enum Command {
+        Get = "/get",
+        Put = "/put",
+    }
+
+    fn name(
+        (version, task_id, agg_job_id):
+        (DapVersion, &'n TaskId, &'n AggregationJobId)
+    ) -> ObjectIdFrom {
+        ObjectIdFrom::Name(format!("{version}/task/{task_id}/agg_job/{agg_job_id}"))
+    }
+}
+
+impl CapnprotoPayloadEncode for ReadyAggregationJobResp {
+    type Builder<'a> = aggregation_job_response::Builder<'a>;
+
+    fn encode_to_builder(&self, builder: Self::Builder<'_>) {
+        let Self { prep_resps } = self;
+        encode_list(
+            prep_resps,
+            builder.init_prep_resps(usize_to_capnp_len(prep_resps.len())),
+        );
+    }
+}
+
+impl CapnprotoPayloadEncode for PrepareResp {
+    type Builder<'a> = aggregation_job_response::prepare_resp::Builder<'a>;
+
+    fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
+        let Self { report_id, var } = self;
+        report_id.encode_to_builder(builder.reborrow().init_report_id());
+        let mut builder = builder.init_var();
+        match var {
+            PrepareRespVar::Continue(vec) => builder.set_continue(vec),
+            PrepareRespVar::Reject(report_error) => builder.set_reject((*report_error).into()),
+        }
+    }
+}
+
+impl CapnprotoPayloadDecode for PrepareResp {
+    type Reader<'a> = aggregation_job_response::prepare_resp::Reader<'a>;
+
+    fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
+    where
+        Self: Sized,
+    {
+        Ok(Self {
+            report_id: <_>::decode_from_reader(reader.get_report_id()?)?,
+            var: match reader.get_var()?.which()? {
+                aggregation_job_response::prepare_resp_var::Which::Continue(data) => {
+                    PrepareRespVar::Continue(data?.to_vec())
+                }
+                aggregation_job_response::prepare_resp_var::Which::Reject(report_error) => {
+                    PrepareRespVar::Reject(report_error?.into())
+                }
+            },
+        })
+    }
+}
+
+impl CapnprotoPayloadDecode for ReadyAggregationJobResp {
+    type Reader<'a> = aggregation_job_response::Reader<'a>;
+
+    fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
+    where
+        Self: Sized,
+    {
+        Ok(Self {
+            prep_resps: decode_list::<PrepareResp, _>(reader.get_prep_resps()?)?,
+        })
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::capnproto::{CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _};
+    use daphne::messages::ReportId;
+    use rand::{thread_rng, Rng};
+
+    fn gen_agg_job_resp() -> ReadyAggregationJobResp {
+        ReadyAggregationJobResp {
+            prep_resps: vec![
+                PrepareResp {
+                    report_id: ReportId(thread_rng().gen()),
+                    var: PrepareRespVar::Continue(vec![1, 2, 3]),
+                },
+                PrepareResp {
+                    report_id: ReportId(thread_rng().gen()),
+                    var: PrepareRespVar::Reject(daphne::messages::ReportError::InvalidMessage),
+                },
+            ],
+        }
+    }
+
+    #[test]
+    fn serialization_deserialization_round_trip() {
+        let this = gen_agg_job_resp();
+        let other = ReadyAggregationJobResp::decode_from_bytes(&this.encode_to_bytes()).unwrap();
+        assert_eq!(this, other);
+    }
+}
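Reviewer note: `AGGREGATE_JOB_RESULT_STORE` keys a `ReadyAggregationJobResp` by `(version, task_id, agg_job_id)`, which is exactly what a poll request carries. The Get/Put pair suggests a mapping like the following on the poll path (assumed logic, not shown in this diff):

```rust
use daphne::{messages::AggregationJobResp, protocol::ReadyAggregationJobResp};

// Assumed poll-side mapping: a missing entry means the queue consumer has not
// finished the job yet; a present entry is the final answer.
fn poll_result(stored: Option<ReadyAggregationJobResp>) -> AggregationJobResp {
    match stored {
        Some(ReadyAggregationJobResp { prep_resps }) => AggregationJobResp::Ready { prep_resps },
        None => AggregationJobResp::Processing,
    }
}
```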
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs
index 0f5b1181d..960fbb94f 100644
--- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store.rs
@@ -1,4 +1,4 @@
-// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
+// Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
 use std::collections::HashSet;
@@ -93,66 +93,7 @@ impl CapnprotoPayloadEncode for AggregateStoreMergeReq {
                 report.set_high(high);
             }
         }
-        {
-            let mut agg_share_delta_packet = builder.reborrow().init_agg_share_delta();
-            agg_share_delta_packet.set_report_count(agg_share_delta.report_count);
-            agg_share_delta_packet.set_min_time(agg_share_delta.min_time);
-            agg_share_delta_packet.set_max_time(agg_share_delta.max_time);
-            {
-                let checksum = agg_share_delta_packet
-                    .reborrow()
-                    .init_checksum(agg_share_delta.checksum.len().try_into().unwrap());
-                checksum.copy_from_slice(&agg_share_delta.checksum);
-            }
-            {
-                macro_rules! make_encode {
-                    ($func_name:ident, $agg_share_type:ident, $field_trait:ident) => {
-                        fn $func_name<'b, F, B, const ENCODED_SIZE: usize>(
-                            field: &$agg_share_type<F>,
-                            get_bytes: B,
-                        ) where
-                            F: $field_trait + Into<[u8; ENCODED_SIZE]>,
-                            B: FnOnce(u32) -> &'b mut [u8],
-                        {
-                            let mut bytes = get_bytes(
-                                (F::ENCODED_SIZE * field.as_ref().len())
-                                    .try_into()
-                                    .expect("trying to encode a buffer longer than u32::MAX"),
-                            );
-                            for f in field.as_ref() {
-                                let f: [u8; ENCODED_SIZE] = (*f).into();
-                                bytes[..ENCODED_SIZE].copy_from_slice(&f);
-                                bytes = &mut bytes[ENCODED_SIZE..];
-                            }
-                        }
-                    };
-                }
-                make_encode!(encode_draft09, AggregateShareDraft09, FieldElementDraft09);
-                make_encode!(encode, AggregateShare, FieldElement);
-                let mut data = agg_share_delta_packet.init_data();
-                match &self.agg_share_delta.data {
-                    Some(VdafAggregateShare::Field64Draft09(field)) => {
-                        encode_draft09(field, |len| data.init_field64_draft09(len));
-                    }
-                    Some(VdafAggregateShare::Field128Draft09(field)) => {
-                        encode_draft09(field, |len| data.init_field128_draft09(len));
-                    }
-                    Some(VdafAggregateShare::Field32Draft09(field)) => {
-                        encode_draft09(field, |len| data.init_field_prio2_draft09(len));
-                    }
-                    Some(VdafAggregateShare::Field64(field)) => {
-                        encode(field, |len| data.init_field64(len));
-                    }
-                    Some(VdafAggregateShare::Field128(field)) => {
-                        encode(field, |len| data.init_field128(len));
-                    }
-                    Some(VdafAggregateShare::Field32(field)) => {
-                        encode(field, |len| data.init_field_prio2(len));
-                    }
-                    None => data.set_none(()),
-                };
-            }
-        }
+        agg_share_delta.encode_to_builder(builder.reborrow().init_agg_share_delta());
         {
             let AggregateStoreMergeOptions {
                 skip_replay_protection,
@@ -167,76 +108,8 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq {
     type Reader<'a> = aggregate_store_merge_req::Reader<'a>;
 
     fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self> {
-        let agg_share_delta = {
-            let agg_share_delta = reader.get_agg_share_delta()?;
-            let data = {
-                macro_rules! make_decode {
-                    ($func_name:ident, $agg_share_type:ident, $field_trait:ident, $field_error:ident) => {
-                        fn $func_name<F>(fields: &[u8]) -> capnp::Result<$agg_share_type<F>>
-                        where
-                            F: $field_trait + for<'s> TryFrom<&'s [u8], Error = $field_error>,
-                        {
-                            let iter = fields.chunks_exact(F::ENCODED_SIZE);
-                            if let length @ 1.. = iter.remainder().len() {
-                                return Err(capnp::Error {
-                                    kind: capnp::ErrorKind::Failed,
-                                    extra: format!(
-                                        "leftover bytes still present in buffer: {length}"
-                                    ),
-                                });
-                            }
-                            Ok($agg_share_type::from(
-                                iter.map(|f| f.try_into().unwrap()).collect::<Vec<_>>(),
-                            ))
-                        }
-                    };
-                }
-                make_decode!(
-                    decode_draft09,
-                    AggregateShareDraft09,
-                    FieldElementDraft09,
-                    FieldErrorDraft09
-                );
-                make_decode!(decode, AggregateShare, FieldElement, FieldError);
-                match agg_share_delta.get_data().which()? {
-                    dap_aggregate_share::data::Which::Field64Draft09(field) => {
-                        Some(VdafAggregateShare::Field64Draft09(decode_draft09(field?)?))
-                    }
-                    dap_aggregate_share::data::Which::Field128Draft09(field) => {
-                        Some(VdafAggregateShare::Field128Draft09(decode_draft09(field?)?))
-                    }
-                    dap_aggregate_share::data::Which::FieldPrio2Draft09(field) => {
-                        Some(VdafAggregateShare::Field32Draft09(decode_draft09(field?)?))
-                    }
-
-                    dap_aggregate_share::data::Which::Field64(field) => {
-                        Some(VdafAggregateShare::Field64(decode(field?)?))
-                    }
-                    dap_aggregate_share::data::Which::Field128(field) => {
-                        Some(VdafAggregateShare::Field128(decode(field?)?))
-                    }
-                    dap_aggregate_share::data::Which::FieldPrio2(field) => {
-                        Some(VdafAggregateShare::Field32(decode(field?)?))
-                    }
-                    dap_aggregate_share::data::Which::None(()) => None,
-                }
-            };
-            DapAggregateShare {
-                report_count: agg_share_delta.get_report_count(),
-                min_time: agg_share_delta.get_min_time(),
-                max_time: agg_share_delta.get_max_time(),
-                checksum: agg_share_delta
-                    .get_checksum()?
-                    .try_into()
-                    .map_err(|_| capnp::Error {
-                        kind: capnp::ErrorKind::Failed,
-                        extra: "checksum had unexpected size".into(),
-                    })?,
-                data,
-            }
-        };
-        let contained_reports = {
-            reader
+        Ok(Self {
+            contained_reports: reader
                 .get_contained_reports()?
                 .into_iter()
                 .map(|report| {
                     let low = report.get_low();
                     let high = report.get_high();
                     let mut buffer = [0; 16];
                     buffer[..8].copy_from_slice(&low.to_le_bytes());
                     buffer[8..].copy_from_slice(&high.to_le_bytes());
                     ReportId(buffer)
                 })
-                .collect()
-        };
-        Ok(Self {
-            contained_reports,
-            agg_share_delta,
+                .collect(),
+            agg_share_delta: <_>::decode_from_reader(reader.get_agg_share_delta()?)?,
             options: AggregateStoreMergeOptions {
                 skip_replay_protection: reader.get_options()?.get_skip_replay_protection(),
             },
@@ -260,6 +130,138 @@ impl CapnprotoPayloadDecode for AggregateStoreMergeReq {
     }
 }
 
+impl CapnprotoPayloadEncode for DapAggregateShare {
+    type Builder<'a> = dap_aggregate_share::Builder<'a>;
+
+    fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
+        builder.set_report_count(self.report_count);
+        builder.set_min_time(self.min_time);
+        builder.set_max_time(self.max_time);
+        builder.set_checksum(&self.checksum);
+        {
+            macro_rules! make_encode {
+                ($func_name:ident, $agg_share_type:ident, $field_trait:ident) => {
+                    fn $func_name<'b, F, B, const ENCODED_SIZE: usize>(
+                        field: &$agg_share_type<F>,
+                        get_bytes: B,
+                    ) where
+                        F: $field_trait + Into<[u8; ENCODED_SIZE]>,
+                        B: FnOnce(u32) -> &'b mut [u8],
+                    {
+                        let mut bytes = get_bytes(
+                            (F::ENCODED_SIZE * field.as_ref().len())
+                                .try_into()
+                                .expect("trying to encode a buffer longer than u32::MAX"),
+                        );
+                        for f in field.as_ref() {
+                            let f: [u8; ENCODED_SIZE] = (*f).into();
+                            bytes[..ENCODED_SIZE].copy_from_slice(&f);
+                            bytes = &mut bytes[ENCODED_SIZE..];
+                        }
+                    }
+                };
+            }
+            make_encode!(encode_draft09, AggregateShareDraft09, FieldElementDraft09);
+            make_encode!(encode, AggregateShare, FieldElement);
+            let mut data = builder.init_data();
+            match &self.data {
+                Some(VdafAggregateShare::Field64Draft09(field)) => {
+                    encode_draft09(field, |len| data.init_field64_draft09(len));
+                }
+                Some(VdafAggregateShare::Field128Draft09(field)) => {
+                    encode_draft09(field, |len| data.init_field128_draft09(len));
+                }
+                Some(VdafAggregateShare::Field32Draft09(field)) => {
+                    encode_draft09(field, |len| data.init_field_prio2_draft09(len));
+                }
+                Some(VdafAggregateShare::Field64(field)) => {
+                    encode(field, |len| data.init_field64(len));
+                }
+                Some(VdafAggregateShare::Field128(field)) => {
+                    encode(field, |len| data.init_field128(len));
+                }
+                Some(VdafAggregateShare::Field32(field)) => {
+                    encode(field, |len| data.init_field_prio2(len));
+                }
+                None => data.set_none(()),
+            };
+        }
+    }
+}
+
+impl CapnprotoPayloadDecode for DapAggregateShare {
+    type Reader<'a> = dap_aggregate_share::Reader<'a>;
+
+    fn decode_from_reader(agg_share_delta: Self::Reader<'_>) -> capnp::Result<Self>
+    where
+        Self: Sized,
+    {
+        let data = {
+            macro_rules! make_decode {
+                ($func_name:ident, $agg_share_type:ident, $field_trait:ident, $field_error:ident) => {
+                    fn $func_name<F>(fields: &[u8]) -> capnp::Result<$agg_share_type<F>>
+                    where
+                        F: $field_trait + for<'s> TryFrom<&'s [u8], Error = $field_error>,
+                    {
+                        let iter = fields.chunks_exact(F::ENCODED_SIZE);
+                        if let length @ 1.. = iter.remainder().len() {
+                            return Err(capnp::Error {
+                                kind: capnp::ErrorKind::Failed,
+                                extra: format!("leftover bytes still present in buffer: {length}"),
+                            });
+                        }
+                        Ok($agg_share_type::from(
+                            iter.map(|f| f.try_into().unwrap()).collect::<Vec<_>>(),
+                        ))
+                    }
+                };
+            }
+            make_decode!(
+                decode_draft09,
+                AggregateShareDraft09,
+                FieldElementDraft09,
+                FieldErrorDraft09
+            );
+            make_decode!(decode, AggregateShare, FieldElement, FieldError);
+            match agg_share_delta.get_data().which()? {
+                dap_aggregate_share::data::Which::Field64Draft09(field) => {
+                    Some(VdafAggregateShare::Field64Draft09(decode_draft09(field?)?))
+                }
+                dap_aggregate_share::data::Which::Field128Draft09(field) => {
+                    Some(VdafAggregateShare::Field128Draft09(decode_draft09(field?)?))
+                }
+                dap_aggregate_share::data::Which::FieldPrio2Draft09(field) => {
+                    Some(VdafAggregateShare::Field32Draft09(decode_draft09(field?)?))
+                }
+
+                dap_aggregate_share::data::Which::Field64(field) => {
+                    Some(VdafAggregateShare::Field64(decode(field?)?))
+                }
+                dap_aggregate_share::data::Which::Field128(field) => {
+                    Some(VdafAggregateShare::Field128(decode(field?)?))
+                }
+                dap_aggregate_share::data::Which::FieldPrio2(field) => {
+                    Some(VdafAggregateShare::Field32(decode(field?)?))
+                }
+                dap_aggregate_share::data::Which::None(()) => None,
+            }
+        };
+        Ok(Self {
+            report_count: agg_share_delta.get_report_count(),
+            min_time: agg_share_delta.get_min_time(),
+            max_time: agg_share_delta.get_max_time(),
+            checksum: agg_share_delta
+                .get_checksum()?
+                .try_into()
+                .map_err(|_| capnp::Error {
+                    kind: capnp::ErrorKind::Failed,
+                    extra: "checksum had unexpected size".into(),
+                })?,
+            data,
+        })
+    }
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub enum AggregateStoreMergeResp {
     Ok,
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.capnp b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.capnp
new file mode 100644
index 000000000..37f5f883f
--- /dev/null
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.capnp
@@ -0,0 +1,11 @@
+# Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+@0x822b0e344bf68531;
+
+using Base = import "../../capnproto/base.capnp";
+
+struct PutRequest @0xbabd9e0f2a99569a {
+    aggShareDelta @0 :import "../durable_request.capnp".DapAggregateShare;
+    aggJobId @1 :Base.AggregationJobId;
+}
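Reviewer note: both `AggregateStoreMergeReq` and the new `aggregate_store_v2::PutRequest` now funnel through the `DapAggregateShare` impls above, which serialize field elements as one flat byte buffer. The length invariant the decoder enforces is simple to state on its own (`ENCODED_SIZE` is the per-element width from the `prio` field traits):

```rust
use prio::field::{Field64, FieldElement};

// A buffer of n Field64 elements must be exactly n * ENCODED_SIZE bytes;
// the decoder rejects any buffer with a nonzero remainder.
fn expected_len(n: usize) -> usize {
    n * Field64::ENCODED_SIZE // 8 bytes per Field64 element
}

fn main() {
    assert_eq!(expected_len(3), 24);
    assert!(25 % Field64::ENCODED_SIZE != 0); // a 25-byte buffer would be rejected
}
```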
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.rs
new file mode 100644
index 000000000..d9ed1057e
--- /dev/null
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregate_store_v2.rs
@@ -0,0 +1,195 @@
+// Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+
+use daphne::{
+    messages::{AggregationJobId, TaskId},
+    DapAggregateShare, DapBatchBucket, DapVersion,
+};
+
+use crate::{
+    aggregate_store_v2_capnp,
+    capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadEncode},
+    durable_requests::ObjectIdFrom,
+};
+
+super::define_do_binding! {
+    const BINDING = "AGGREGATE_STORE";
+    enum Command {
+        Get = "/get",
+        Put = "/put",
+        MarkCollected = "/mark-collected",
+        CheckCollected = "/check-collected",
+        AggregateShareCount = "/aggregate-share-count",
+    }
+
+    fn name(
+        (version, task_id, bucket):
+        (DapVersion, &'n TaskId, &'n DapBatchBucket)
+    ) -> ObjectIdFrom {
+        ObjectIdFrom::Name(format!("{version}/task/{task_id}/batch_bucket/{bucket}"))
+    }
+}
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct PutRequest {
+    pub agg_share_delta: DapAggregateShare,
+    pub agg_job_id: AggregationJobId,
+}
+
+impl CapnprotoPayloadEncode for PutRequest {
+    type Builder<'a> = aggregate_store_v2_capnp::put_request::Builder<'a>;
+
+    fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
+        let Self {
+            agg_share_delta,
+            agg_job_id,
+        } = self;
+        agg_share_delta.encode_to_builder(builder.reborrow().init_agg_share_delta());
+        agg_job_id.encode_to_builder(builder.reborrow().init_agg_job_id());
+    }
+}
+
+impl CapnprotoPayloadDecode for PutRequest {
+    type Reader<'a> = aggregate_store_v2_capnp::put_request::Reader<'a>;
+
+    fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self> {
+        Ok(Self {
+            agg_share_delta: <_>::decode_from_reader(reader.get_agg_share_delta()?)?,
+            agg_job_id: <_>::decode_from_reader(reader.get_agg_job_id()?)?,
+        })
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use prio::{
+        codec::Decode,
+        field::{Field128, Field64, FieldElement, FieldPrio2},
+        vdaf::AggregateShare,
+    };
+    use prio_draft09::{
+        codec::Decode as DecodeDraft09,
+        field::{
+            Field128 as Field128Draft09, Field64 as Field64Draft09,
+            FieldElement as FieldElementDraft09, FieldPrio2 as FieldPrio2Draft09,
+        },
+        vdaf::AggregateShare as AggregateShareDraft09,
+    };
+    use rand::{thread_rng, Rng};
+
+    use crate::capnproto::{CapnprotoPayloadDecodeExt as _, CapnprotoPayloadEncodeExt as _};
+
+    use super::*;
+    use daphne::vdaf::VdafAggregateShare;
+
+    #[test]
+    fn serialization_deserialization_round_trip_draft09() {
+        let mut rng = thread_rng();
+        for len in 0..20 {
+            let test_data = [
+                VdafAggregateShare::Field64Draft09(AggregateShareDraft09::from(
+                    (0..len)
+                        .map(|_| {
+                            Field64Draft09::get_decoded(
+                                &rng.gen::<[_; Field64Draft09::ENCODED_SIZE]>(),
+                            )
+                            .unwrap()
+                        })
+                        .collect::<Vec<_>>(),
+                )),
+                VdafAggregateShare::Field128Draft09(AggregateShareDraft09::from(
+                    (0..len)
+                        .map(|_| {
+                            Field128Draft09::get_decoded(
+                                &rng.gen::<[_; Field128Draft09::ENCODED_SIZE]>(),
+                            )
+                            .unwrap()
+                        })
+                        .collect::<Vec<_>>(),
+                )),
+                VdafAggregateShare::Field32Draft09(AggregateShareDraft09::from(
+                    (0..len)
+                        .map(|_| {
+                            // idk how to consistently generate a valid FieldPrio2 value, so I just
+                            // retry until I hit a valid one. Doesn't usually take too long.
+                            (0..)
+                                .find_map(|_| {
+                                    FieldPrio2Draft09::get_decoded(&rng.gen::<[_; 4]>()).ok()
+                                })
+                                .unwrap()
+                        })
+                        .collect::<Vec<_>>(),
+                )),
+            ]
+            .map(Some)
+            .into_iter()
+            .chain([None]);
+            for data in test_data {
+                let this = PutRequest {
+                    agg_job_id: AggregationJobId(rng.gen()),
+                    agg_share_delta: DapAggregateShare {
+                        report_count: rng.gen(),
+                        min_time: rng.gen(),
+                        max_time: rng.gen(),
+                        checksum: rng.gen(),
+                        data,
+                    },
+                };
+                let other = PutRequest::decode_from_bytes(&this.encode_to_bytes()).unwrap();
+                assert_eq!(this, other);
+            }
+        }
+    }
+
+    #[test]
+    fn serialization_deserialization_round_trip() {
+        let mut rng = thread_rng();
+        for len in 0..20 {
+            let test_data = [
+                VdafAggregateShare::Field64(AggregateShare::from(
+                    (0..len)
+                        .map(|_| {
+                            Field64::get_decoded(&rng.gen::<[_; Field64::ENCODED_SIZE]>()).unwrap()
+                        })
+                        .collect::<Vec<_>>(),
+                )),
+                VdafAggregateShare::Field128(AggregateShare::from(
+                    (0..len)
+                        .map(|_| {
+                            Field128::get_decoded(&rng.gen::<[_; Field128::ENCODED_SIZE]>())
+                                .unwrap()
+                        })
+                        .collect::<Vec<_>>(),
+                )),
+                VdafAggregateShare::Field32(AggregateShare::from(
+                    (0..len)
+                        .map(|_| {
+                            // idk how to consistently generate a valid FieldPrio2 value, so I just
+                            // retry until I hit a valid one. Doesn't usually take too long.
+                            (0..)
+                                .find_map(|_| FieldPrio2::get_decoded(&rng.gen::<[_; 4]>()).ok())
+                                .unwrap()
+                        })
+                        .collect::<Vec<_>>(),
+                )),
+            ]
+            .map(Some)
+            .into_iter()
+            .chain([None]);
+            for data in test_data {
+                let this = PutRequest {
+                    agg_job_id: AggregationJobId(rng.gen()),
+                    agg_share_delta: DapAggregateShare {
+                        report_count: rng.gen(),
+                        min_time: rng.gen(),
+                        max_time: rng.gen(),
+                        checksum: rng.gen(),
+                        data,
+                    },
+                };
+                let other = PutRequest::decode_from_bytes(&this.encode_to_bytes()).unwrap();
+                assert_eq!(this, other);
+            }
+        }
+    }
+}
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.capnp b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.capnp
index cbbb64839..97f2e987e 100644
--- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.capnp
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.capnp
@@ -5,7 +5,7 @@
 
 using Base = import "../../capnproto/base.capnp";
 
-struct NewJobRequest {
+struct NewJobRequest @0xdd285ccdbb2cd14e {
     id @0 :Base.AggregationJobId;
     aggJobHash @1 :Data;
 }
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs
index 9271f707d..11a0eae78 100644
--- a/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/aggregation_job_store.rs
@@ -18,7 +18,7 @@ super::define_do_binding! {
 
     enum Command {
         NewJob = "/new-job",
-        ListJobIds = "/job-ids",
+        ContainsJob = "/contains",
     }
 
     fn name((version, task_id): (DapVersion, &'n TaskId)) -> ObjectIdFrom {
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs b/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs
index 8030e9b52..20d9a8425 100644
--- a/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/mod.rs
@@ -6,8 +6,11 @@
 //!
 //! It also defines types that are used as the body of requests sent to these objects.
 
+pub mod agg_job_response_store;
 mod aggregate_store;
+pub mod aggregate_store_v2;
 pub mod aggregation_job_store;
+pub mod replay_checker;
 
 #[cfg(feature = "test-utils")]
 mod test_state_cleaner;
@@ -48,6 +51,8 @@ macro_rules! define_do_binding {
         fn name($params:tt : $params_ty:ty) -> ObjectIdFrom
         $name_impl:block
     ) => {
+        $(const _: () = assert!(matches!($route.as_bytes().first(), Some(b'/')));)*
+
         #[derive(
             serde::Serialize,
             serde::Deserialize,
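Reviewer note: the `const _: () = assert!(...)` added to `define_do_binding!` turns a malformed route into a build failure instead of a runtime routing bug. The trick in isolation:

```rust
// Standalone sketch of the compile-time check: every route must start with '/'.
const fn starts_with_slash(route: &str) -> bool {
    matches!(route.as_bytes().first(), Some(b'/'))
}

const _: () = assert!(starts_with_slash("/check")); // compiles
// const _: () = assert!(starts_with_slash("check")); // would fail the build

fn main() {}
```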
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.capnp b/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.capnp
new file mode 100644
index 000000000..590ac1687
--- /dev/null
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.capnp
@@ -0,0 +1,11 @@
+# Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+
+@0xaaa529cce40f45d7;
+
+using Base = import "../../capnproto/base.capnp";
+
+struct CheckReplaysFor @0xe1e6a4a1695238ca {
+    reports @0 :List(Base.ReportId);
+    aggregationJobId @1 :Base.AggregationJobId;
+}
diff --git a/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.rs b/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.rs
new file mode 100644
index 000000000..816d24066
--- /dev/null
+++ b/crates/daphne-service-utils/src/durable_requests/bindings/replay_checker.rs
@@ -0,0 +1,68 @@
+// Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+
+use crate::{
+    capnproto::{
+        decode_list, encode_list, usize_to_capnp_len, CapnprotoPayloadDecode,
+        CapnprotoPayloadEncode,
+    },
+    durable_requests::ObjectIdFrom,
+    replay_checker_capnp::check_replays_for,
+};
+use daphne::messages::{AggregationJobId, ReportId, TaskId, Time};
+use serde::{Deserialize, Serialize};
+use std::{borrow::Cow, collections::HashSet};
+
+super::define_do_binding! {
+    const BINDING = "REPLAY_CHECK_STORE";
+
+    enum Command {
+        Check = "/check",
+    }
+
+    fn name((task_id, epoch, shard): (&'n TaskId, Time, usize)) -> ObjectIdFrom {
+        ObjectIdFrom::Name(format!("replay-checker/{task_id}/epoch/{epoch}/shard/{shard}"))
+    }
+}
+
+pub struct Request<'s> {
+    pub report_ids: Cow<'s, [ReportId]>,
+    pub aggregation_job_id: AggregationJobId,
+}
+
+impl CapnprotoPayloadEncode for Request<'_> {
+    type Builder<'a> = check_replays_for::Builder<'a>;
+
+    fn encode_to_builder(&self, mut builder: Self::Builder<'_>) {
+        let Self {
+            report_ids,
+            aggregation_job_id,
+        } = self;
+        encode_list(
+            report_ids.iter(),
+            builder
+                .reborrow()
+                .init_reports(usize_to_capnp_len(report_ids.len())),
+        );
+        aggregation_job_id.encode_to_builder(builder.init_aggregation_job_id());
+    }
+}
+
+impl CapnprotoPayloadDecode for Request<'static> {
+    type Reader<'a> = check_replays_for::Reader<'a>;
+
+    fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result<Self>
+    where
+        Self: Sized,
+    {
+        Ok(Self {
+            report_ids: decode_list::<ReportId, _>(reader.get_reports()?)?,
+            aggregation_job_id: <_>::decode_from_reader(reader.get_aggregation_job_id()?)?,
+        })
+    }
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Response {
+    pub duplicates: HashSet<ReportId>,
+}
diff --git a/crates/daphne-service-utils/src/lib.rs b/crates/daphne-service-utils/src/lib.rs
index cc255864b..dcd9018d0 100644
--- a/crates/daphne-service-utils/src/lib.rs
+++ b/crates/daphne-service-utils/src/lib.rs
@@ -46,8 +46,42 @@ mod aggregation_job_store_capnp {
     ));
 }
 
+#[cfg(feature = "durable_requests")]
+mod agg_job_response_store_capnp {
+    #![allow(dead_code)]
+    #![allow(clippy::pedantic)]
+    #![allow(clippy::needless_lifetimes)]
+    include!(concat!(
+        env!("OUT_DIR"),
+        "/src/durable_requests/bindings/agg_job_response_store_capnp.rs"
+    ));
+}
+
+#[cfg(feature = "durable_requests")]
+mod aggregate_store_v2_capnp {
+    #![allow(dead_code)]
+    #![allow(clippy::pedantic)]
+    #![allow(clippy::needless_lifetimes)]
+    include!(concat!(
+        env!("OUT_DIR"),
+        "/src/durable_requests/bindings/aggregate_store_v2_capnp.rs"
+    ));
+}
+
+#[cfg(feature = "durable_requests")]
+mod replay_checker_capnp {
+    #![allow(dead_code)]
+    #![allow(clippy::pedantic)]
+    #![allow(clippy::needless_lifetimes)]
+    include!(concat!(
+        env!("OUT_DIR"),
+        "/src/durable_requests/bindings/replay_checker_capnp.rs"
+    ));
+}
+
 #[cfg(feature = "compute-offload")]
-mod compute_offload_capnp {
+#[doc(hidden)]
+pub mod compute_offload_capnp {
     #![allow(dead_code)]
     #![allow(clippy::pedantic)]
     #![allow(clippy::needless_lifetimes)]
diff --git a/crates/daphne-worker-test/docker/aggregator.Dockerfile b/crates/daphne-worker-test/docker/aggregator.Dockerfile
index fdb1534d5..29cb633e8 100644
--- a/crates/daphne-worker-test/docker/aggregator.Dockerfile
+++ b/crates/daphne-worker-test/docker/aggregator.Dockerfile
@@ -1,12 +1,13 @@
 # Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-FROM rust:1.80-bookworm AS builder
+FROM rust:1.84.1-bookworm AS builder
 
 RUN apt update && apt install -y capnproto clang cmake
 
 # Pre-install worker-build and Rust's wasm32 target to speed up our custom build command
 RUN rustup target add wasm32-unknown-unknown
-RUN cargo install worker-build@0.1.1 --locked
+RUN echo Ola
+RUN cargo install worker-build@0.1.2 --locked
 
 # Build the worker.
 WORKDIR /tmp/dap_test
diff --git a/crates/daphne-worker-test/docker/runtests.Dockerfile b/crates/daphne-worker-test/docker/runtests.Dockerfile
index 3b8a2dae1..5d2209927 100644
--- a/crates/daphne-worker-test/docker/runtests.Dockerfile
+++ b/crates/daphne-worker-test/docker/runtests.Dockerfile
@@ -1,7 +1,7 @@
 # Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-FROM rust:1.80-bookworm
+FROM rust:1.84.1-bookworm
 
 WORKDIR /tmp/dap_test
 
diff --git a/crates/daphne-worker-test/docker/storage-proxy.Dockerfile b/crates/daphne-worker-test/docker/storage-proxy.Dockerfile
index 7f56133c6..804e71629 100644
--- a/crates/daphne-worker-test/docker/storage-proxy.Dockerfile
+++ b/crates/daphne-worker-test/docker/storage-proxy.Dockerfile
@@ -1,12 +1,13 @@
 # Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-FROM rust:1.80-bookworm AS builder
+FROM rust:1.84.1-bookworm AS builder
 
 RUN apt update && apt install -y capnproto clang cmake
 
 # Pre-install worker-build and Rust's wasm32 target to speed up our custom build command
 RUN rustup target add wasm32-unknown-unknown
-RUN cargo install worker-build@0.1.1 --locked
+RUN echo ola
+RUN cargo install worker-build@0.1.2 --locked
 
 # Build the storage proxy.
 WORKDIR /tmp/dap_test
diff --git a/crates/daphne-worker-test/src/durable.rs b/crates/daphne-worker-test/src/durable.rs
index 6fb72a456..72606f377 100644
--- a/crates/daphne-worker-test/src/durable.rs
+++ b/crates/daphne-worker-test/src/durable.rs
@@ -11,6 +11,14 @@ instantiate_durable_object! {
     }
 }
 
+instantiate_durable_object! {
+    struct ReplayChecker < durable::ReplayChecker;
+
+    fn init_user_data(_state: State, env: Env) {
+        daphne_worker::tracing_utils::initialize_tracing(env);
+    }
+}
+
 instantiate_durable_object! {
     struct AggregationJobStore < durable::AggregationJobStore;
 
@@ -18,3 +26,19 @@ instantiate_durable_object! {
         daphne_worker::tracing_utils::initialize_tracing(env);
     }
 }
+
+instantiate_durable_object! {
+    struct AggJobResponseStore < durable::AggJobResponseStore;
+
+    fn init_user_data(_state: State, env: Env) {
+        daphne_worker::tracing_utils::initialize_tracing(env);
+    }
+}
+
+instantiate_durable_object! {
+    struct AggregateStoreV2 < durable::AggregateStoreV2;
+
+    fn init_user_data(_state: State, env: Env) {
+        daphne_worker::tracing_utils::initialize_tracing(env);
+    }
+}
diff --git a/crates/daphne-worker-test/src/lib.rs b/crates/daphne-worker-test/src/lib.rs
index a48da5f3d..2d191ed8b 100644
--- a/crates/daphne-worker-test/src/lib.rs
+++ b/crates/daphne-worker-test/src/lib.rs
@@ -5,7 +5,7 @@ use daphne_worker::{aggregator::App, initialize_tracing};
 use futures::stream;
 use std::convert::Infallible;
 use tracing::info;
-use worker::{event, Env, HttpRequest, ResponseBody};
+use worker::{event, Env, HttpRequest, MessageBatch, ResponseBody};
 
 mod durable;
 mod utils;
@@ -13,6 +13,11 @@ mod utils;
 #[global_allocator]
 static CAP: cap::Cap<std::alloc::System> = cap::Cap::new(std::alloc::System, 65_000_000);
 
+fn load_compute_offload_host(env: &worker::Env) -> String {
+    env.var("COMPUTE_OFFLOAD_HOST")
+        .map_or_else(|_| "localhost:5000".into(), |t| t.to_string())
+}
+
 #[event(fetch, respond_with_errors)]
 pub async fn main(
     req: HttpRequest,
@@ -39,9 +44,7 @@ pub async fn main(
             daphne_worker::storage_proxy::handle_request(req, env, &registry).await
         }
         Some("aggregator") => {
-            let host = env
-                .var("COMPUTE_OFFLOAD_HOST")
-                .map_or_else(|_| "localhost:5000".into(), |t| t.to_string());
+            let host = load_compute_offload_host(&env);
 
             daphne_worker::aggregator::handle_dap_request(
                 App::new(env, &registry, None, Box::new(ComputeOffload { host })).unwrap(),
@@ -96,3 +99,16 @@ impl daphne_worker::aggregator::ComputeOffload for ComputeOffload {
             .unwrap())
     }
 }
+
+#[event(queue)]
+pub async fn queue(
+    batch: MessageBatch<()>,
+    env: worker::Env,
+    _ctx: worker::Context,
+) -> worker::Result<()> {
+    let registry = prometheus::Registry::new();
+    let host = load_compute_offload_host(&env);
+    let app = App::new(env, &registry, None, Box::new(ComputeOffload { host })).unwrap();
+    daphne_worker::aggregator::queues::async_aggregate(app, batch).await;
+    Ok(())
+}
diff --git a/crates/daphne-worker-test/wrangler.aggregator.toml b/crates/daphne-worker-test/wrangler.aggregator.toml
index a969be856..00d311001 100644
--- a/crates/daphne-worker-test/wrangler.aggregator.toml
+++ b/crates/daphne-worker-test/wrangler.aggregator.toml
@@ -68,9 +68,19 @@ ip = "0.0.0.0"
 bindings = [
     { name = "DAP_AGGREGATE_STORE", class_name = "AggregateStore" },
     { name = "DAP_TEST_STATE_CLEANER", class_name = "TestStateCleaner" },
+    { name = "REPLAY_CHECK_STORE", class_name = "ReplayChecker" },
     { name = "AGGREGATION_JOB_STORE", class_name = "AggregationJobStore" },
+    { name = "AGGREGATE_JOB_RESULT_STORE", class_name = "AggJobResponseStore" },
+    { name = "AGGREGATE_STORE", class_name = "AggregateStoreV2" },
 ]
 
+[[env.helper.queues.producers]]
+queue = "async-aggregation-queue"
+binding = "ASYNC_AGGREGATION_QUEUE"
+
+[[env.helper.queues.consumers]]
+queue = "async-aggregation-queue"
+max_retries = 10
+
 [[env.helper.kv_namespaces]]
 binding = "DAP_CONFIG"
@@ -131,7 +141,7 @@ public_key = "047dab625e0d269abcc28c611bebf5a60987ddf7e23df0e0aa343e5774ad81a1d0
 bindings = [
     { name = "DAP_AGGREGATE_STORE", class_name = "AggregateStore" },
     { name = "DAP_TEST_STATE_CLEANER", class_name = "TestStateCleaner" },
-    { name = "AGGREGATION_JOB_STORE", class_name = "AggregationJobStore" },
+    { name = "AGGREGATE_STORE", class_name = "AggregateStoreV2" },
 ]
 
 [[env.leader.kv_namespaces]]
@@ -159,4 +169,6 @@ new_classes = [
     "AggregateStore",
     "GarbageCollector",
     "AggregationJobStore",
+    "AggJobResponseStore",
+    "AggregateStoreV2",
 ]
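Reviewer note: the queue is wired in three places that must agree: the `queue` name ties the producer and consumer entries together, and the producer's `binding` string is what the Rust side looks up. A sketch of that lookup, mirroring `App::async_aggregation_queue` from this diff (returning the `Result` instead of unwrapping is this sketch's choice):

```rust
// The string must match `binding = "ASYNC_AGGREGATION_QUEUE"` in
// wrangler.aggregator.toml; a mismatch only surfaces at runtime as a
// missing-binding error.
fn async_aggregation_queue(env: &worker::Env) -> worker::Result<worker::Queue> {
    env.get_binding::<worker::Queue>("ASYNC_AGGREGATION_QUEUE")
}
```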
diff --git a/crates/daphne-worker/Cargo.toml b/crates/daphne-worker/Cargo.toml
index da77cbca5..d0014848b 100644
--- a/crates/daphne-worker/Cargo.toml
+++ b/crates/daphne-worker/Cargo.toml
@@ -19,6 +19,7 @@ crate-type = ["cdylib", "rlib"]
 async-trait = { workspace = true }
 axum-extra = { workspace = true, features = ["typed-header"] }
 bytes.workspace = true
+capnp = { workspace = true }
 chrono = { workspace = true, default-features = false, features = ["clock", "wasmbind"] }
 constcat.workspace = true
 daphne = { path = "../daphne", features = ["prometheus"] }
@@ -49,7 +50,8 @@ tracing-core.workspace = true
 tracing-subscriber = { workspace = true, features = ["env-filter", "json"]}
 tracing.workspace = true
 url.workspace = true
-worker.workspace = true
+wasm-bindgen.workspace = true
+worker = { workspace = true, features = ["http", "queue"] }
 
 [dependencies.axum]
 workspace = true
@@ -67,6 +69,9 @@ reqwest.workspace = true # used in doc tests
 tokio.workspace = true
 webpki.workspace = true
 
+[build-dependencies]
+capnpc = { workspace = true }
+
 [features]
 test-utils = ["daphne-service-utils/test-utils"]
diff --git a/crates/daphne-worker/build.rs b/crates/daphne-worker/build.rs
new file mode 100644
index 000000000..77c603c87
--- /dev/null
+++ b/crates/daphne-worker/build.rs
@@ -0,0 +1,10 @@
+// Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
+// SPDX-License-Identifier: BSD-3-Clause
+
+fn main() {
+    ::capnpc::CompilerCommand::new()
+        .import_path("../daphne-service-utils/src")
+        .file("./src/aggregator/queues/queue_messages.capnp")
+        .run()
+        .expect("compiling schema");
+}
diff --git a/crates/daphne-worker/src/aggregator/config.rs b/crates/daphne-worker/src/aggregator/config.rs
index d07498c04..e4435cc4a 100644
--- a/crates/daphne-worker/src/aggregator/config.rs
+++ b/crates/daphne-worker/src/aggregator/config.rs
@@ -194,10 +194,14 @@
 pub fn load_config_from_env(env: &worker::Env) -> Result<DaphneServiceConfig, DapError> {
-    let config = env
-        .object_var::<DaphneServiceConfig>(SERVICE_CONFIG)
+    // due to a bug where JsValue::UNDEFINED is not considered nullish in serde-wasm-bindgen we
+    // have to deserialize to serde_json::Value first and DaphneServiceConfig later.
+    let config = env
+        .object_var::<serde_json::Value>(SERVICE_CONFIG)
         .map_err(|e| fatal_error!(err = ?e, "failed to load SERVICE_CONFIG variable"))?;
+    let mut config = serde_json::from_value::<DaphneServiceConfig>(config).unwrap();
+
     if config.taskprov.is_some() {
         tracing::warn!("taskprov secrets are defined in plain text. Prefer using wrangler secrets");
     } else if matches!(env.var(taskprov_secrets::ENABLED), Ok(s) if s.to_string() == "true") {
@@ -236,9 +240,16 @@ mod taskprov_secrets {
     pub fn load(env: &worker::Env) -> Result<super::TaskprovConfig, DapError> {
         Ok(super::TaskprovConfig {
-            hpke_collector_config: env.object_var(TASKPROV_HPKE_COLLECTOR_CONFIG).map_err(
-                |e| fatal_error!(err = ?e, "failed to load TASKPROV_HPKE_COLLECTOR_CONFIG"),
-            )?,
+            hpke_collector_config: env
+                .object_var::<serde_json::Value>(TASKPROV_HPKE_COLLECTOR_CONFIG)
+                .map_err(
+                    |e| fatal_error!(err = ?e, "failed to load TASKPROV_HPKE_COLLECTOR_CONFIG"),
+                )
+                .and_then(|v| {
+                    serde_json::from_value(v).map_err(
+                        |e| fatal_error!(err = ?e, "failed to load TASKPROV_HPKE_COLLECTOR_CONFIG"),
+                    )
+                })?,
             vdaf_verify_key_init: {
                 let key = VDAF_VERIFY_KEY_INIT;
                 hex::decode(
diff --git a/crates/daphne-worker/src/aggregator/mod.rs b/crates/daphne-worker/src/aggregator/mod.rs
index 1e0254112..063c7e682 100644
--- a/crates/daphne-worker/src/aggregator/mod.rs
+++ b/crates/daphne-worker/src/aggregator/mod.rs
@@ -3,6 +3,7 @@
 
 mod config;
 mod metrics;
+pub mod queues;
 mod roles;
 mod router;
 
@@ -31,6 +32,7 @@ use router::DaphneService;
 use std::sync::{Arc, LazyLock, Mutex};
 use worker::send::SendWrapper;
 
+use queues::Queue;
 pub use router::handle_dap_request;
 
 #[async_trait::async_trait(?Send)]
@@ -194,4 +196,12 @@ impl App {
     fn bearer_tokens(&self) -> BearerTokens<'_> {
         BearerTokens::from(Kv::new(&self.env, &self.kv_state))
    }
+
+    fn async_aggregation_queue(&self) -> Queue<queues::AsyncAggregationMessage<'static>> {
+        Queue::from(
+            self.env
+                .get_binding::<worker::Queue>("ASYNC_AGGREGATION_QUEUE")
+                .unwrap(),
+        )
+    }
 }
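Reviewer note: the same two-step decode (`object_var` to `serde_json::Value`, then `serde_json::from_value`) now appears twice in config.rs to dodge the serde-wasm-bindgen nullish bug. If a third call site shows up, it could be factored out; a hypothetical helper, not part of this diff:

```rust
// Hypothetical helper capturing the workaround used above.
fn object_var_via_json<T: serde::de::DeserializeOwned>(
    env: &worker::Env,
    name: &str,
) -> worker::Result<T> {
    // Go through serde_json::Value first: serde-wasm-bindgen does not treat
    // JsValue::UNDEFINED as nullish, so decoding the target type directly misbehaves.
    let raw: serde_json::Value = env.object_var(name)?;
    serde_json::from_value(raw).map_err(|e| worker::Error::RustError(e.to_string()))
}
```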
+// SPDX-License-Identifier: BSD-3-Clause + +use crate::{ + aggregator::App, + elapsed, queue_messages_capnp, + storage::{self, Do}, +}; +use daphne::{ + messages::{ + taskprov::TaskprovAdvertisement, AggregationJobId, PartialBatchSelector, ReportId, + ReportMetadata, TaskId, Time, + }, + roles::helper::handle_agg_job::ToInitializedReportsTransition, + DapVersion, +}; +use daphne_service_utils::{ + capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadDecodeExt, CapnprotoPayloadEncode}, + compute_offload, + durable_requests::bindings::{ + agg_job_response_store, aggregate_store_v2, + replay_checker::{self, Command}, + }, +}; +use futures::{stream::FuturesUnordered, StreamExt, TryStreamExt}; +use prio::codec::ParameterizedDecode; +use std::{ + collections::{HashMap, HashSet}, + future::Future, + num::NonZeroUsize, + time::Duration, +}; +use worker::{MessageBatch, MessageExt, RawMessage}; + +fn deserialize(message: &RawMessage) -> worker::Result { + let buf: worker::js_sys::Uint8Array = message.body().into(); + T::decode_from_bytes(&buf.to_vec()).map_err(|e| worker::Error::RustError(e.to_string())) +} + +pub struct AsyncAggregationMessage<'s> { + pub version: DapVersion, + pub part_batch_sel: PartialBatchSelector, + pub agg_job_id: AggregationJobId, + pub initialize_reports: compute_offload::InitializeReports<'s>, + pub taskprov_advertisement: Option, +} + +impl CapnprotoPayloadEncode for AsyncAggregationMessage<'_> { + type Builder<'a> = queue_messages_capnp::async_aggregation_message::Builder<'a>; + + fn encode_to_builder(&self, mut builder: Self::Builder<'_>) { + let Self { + version, + part_batch_sel, + agg_job_id, + initialize_reports, + taskprov_advertisement, + } = self; + builder.set_version((*version).into()); + part_batch_sel.encode_to_builder(builder.reborrow().init_partial_batch_selector()); + agg_job_id.encode_to_builder(builder.reborrow().init_aggregation_job_id()); + initialize_reports.encode_to_builder(builder.reborrow().init_initialize_reports()); + match taskprov_advertisement.as_deref() { + Some(ta) => builder + .init_taskprov_advertisement() + .set_some(ta.into()) + .unwrap(), + None => builder.init_taskprov_advertisement().set_none(()), + } + } +} + +impl CapnprotoPayloadDecode for AsyncAggregationMessage<'static> { + type Reader<'a> = queue_messages_capnp::async_aggregation_message::Reader<'a>; + + fn decode_from_reader(reader: Self::Reader<'_>) -> capnp::Result + where + Self: Sized, + { + Ok(Self { + version: reader.get_version()?.into(), + agg_job_id: <_>::decode_from_reader(reader.get_aggregation_job_id()?)?, + part_batch_sel: <_>::decode_from_reader(reader.get_partial_batch_selector()?)?, + initialize_reports: <_>::decode_from_reader(reader.get_initialize_reports()?)?, + taskprov_advertisement: match reader.get_taskprov_advertisement()?.which()? 
{ + queue_messages_capnp::option::Which::None(()) => None, + queue_messages_capnp::option::Which::Some(text) => Some(text?.to_string()?), + }, + }) + } +} + +async fn shard_reports( + durable: Do<'_>, + task_id: &TaskId, + time_precision: Time, + agg_job_id: AggregationJobId, + reports: impl Iterator, +) -> Result, storage::Error> { + let mut shards = HashMap::<_, Vec<_>>::new(); + for r in reports { + let epoch = r.time - (r.time % time_precision); + let shard = r.id.shard(NonZeroUsize::new(1024).unwrap()); + shards.entry((epoch, shard)).or_default().push(r.id); + } + + futures::stream::iter(shards) + .map(|((epoch, shard), report_ids)| async move { + durable + .with_retry() + .request(Command::Check, (task_id, epoch, shard)) + .encode(&replay_checker::Request { + report_ids: report_ids.into(), + aggregation_job_id: agg_job_id, + }) + .send::() + .await + .map(|r| r.duplicates) + }) + .buffer_unordered(6) + .try_fold(HashSet::new(), |mut acc, dups| async move { + acc.extend(dups); + Ok(acc) + }) + .await +} + +macro_rules! bail { + (retry $m:ident, err = $error:expr, $msg:literal) => {{ + $m.retry(); + bail!(err = $error, $msg) + }}; + (err = $error:expr, $msg:literal) => {{ + tracing::error!(error = ?$error, $msg); + return; + }} +} + +/// Perform an aggregation job. +/// +/// ## Note +/// There is a worst case scenario that this handler can't deal with. +/// +/// Messages will be replayed if they fail, but the workers runtime will eventually give up on a +/// message after it's been retried a bunch of times, in that case the helper will never respond +/// positively to the poll request, possibly leaving the leader in an infinite loop state. The +/// leader can resubmit the work as many times as it wants to get out of this situation, but +/// implementers of the leader must be made aware of this. +// +// ----- +// +// All of the IO in this function is idempotent. They can be spotted by looking at the `.await` +// expressions. The explanation is as follows: +// 1. Getting a task config. No writes performed. +// 2. Initializing the reports. Stateless. +// 3. The same aggregation job may replay it's own reports, this means replay checking is +// idempotent. See the [ReplayChecker] durable object for more details. +// 4. Storing the aggregate share does not merge with any other aggregate share and simply +// replaces the previous one, which will be identical. +// 5. Storing the aggregate response simply overwrites the previous response, which will be +// identical. +// +#[tracing::instrument(skip_all, fields(?version, ?part_batch_sel, ?agg_job_id))] +pub async fn async_aggregate_one( + app: &App, + message: &RawMessage, + AsyncAggregationMessage { + version, + part_batch_sel, + agg_job_id, + initialize_reports, + taskprov_advertisement, + }: AsyncAggregationMessage<'_>, +) { + let task_id = &initialize_reports.task_id; + tracing::debug!("getting task config"); + let taskprov_advertisement = match taskprov_advertisement + .map(|s| TaskprovAdvertisement::parse_taskprov_advertisement(&s, task_id, version)) + .transpose() + { + Ok(taskprov_advertisement) => taskprov_advertisement, + Err(e) => bail!(err = e, "taskprov advertisement was illformed"), + }; + + // 1. 
+ let (task_config, get_task_config_duration) = time_it(daphne::roles::resolve_task_config( + app, + &daphne::DapRequestMeta { + version, + media_type: None, + task_id: *task_id, + taskprov_advertisement, + }, + )) + .await; + + let task_config = match task_config { + Ok(t) => t, + Err(e) => bail!(retry message, err = e, "no such task config"), + }; + + let agg_param = match daphne::DapAggregationParam::get_decoded_with_param( + &task_config.vdaf, + &initialize_reports.agg_param, + ) { + Ok(param) => param, + Err(e) => bail!(err = e, "dap aggregation parameter was ill-formed"), + }; + + // 2. + tracing::debug!("initializing reports"); + let (initialized_reports, initialize_reports_duration) = match time_it( + app.compute_offload + .compute::<_, compute_offload::InitializedReports>( + "/compute_offload/initialize_reports", + &initialize_reports, + ), + ) + .await + { + (Ok(init), duration) => (init, duration), + (Err(e), _) => bail!(retry message, err = e, "failed to initialize reports"), + }; + + let time_precision = task_config.time_precision; + let state_machine = ToInitializedReportsTransition { + task_id: *task_id, + part_batch_sel, + task_config, + } + .with_initialized_reports(agg_param, initialized_reports.reports); + + // 3. + tracing::debug!("checking replays"); + let (state_machine, check_for_replays_duration) = + time_it(state_machine.check_for_replays(|report_ids| { + let report_ids = report_ids.cloned().collect::<Vec<_>>(); + shard_reports( + app.durable(), + &initialize_reports.task_id, + time_precision, + agg_job_id, + report_ids.into_iter(), + ) + })) + .await; + + let state_machine = match state_machine { + Ok(st) => st, + Err(e) => bail!(retry message, err = e, "failed to check replays"), + }; + + let (span, agg_job_response) = match state_machine.finish() { + Ok(output) => output, + // this error is always caused by a bug in the code, it's not recoverable + Err(e) => bail!(err = e, "failed to finish aggregation"), + }; + + let ((), store_aggregate_share_duration) = time_it(async { + for (bucket, (share, _)) in span { + let request = aggregate_store_v2::PutRequest { + agg_job_id, + agg_share_delta: share, + }; + // 4. + tracing::debug!("storing aggregate shares"); + let response = app + .durable() + .with_retry() + .request( + aggregate_store_v2::Command::Put, + (version, task_id, &bucket), + ) + .encode(&request) + .send::<()>() + .await; + match response { + Ok(()) => {} + Err(e) => bail!(retry message, err = e, "failed to store aggregate share"), + } + } + }) + .await; + + // 5.
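+ // Step 5: store the aggregation job response; a retry overwrites the previous, identical response, keeping this write idempotent.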
+ tracing::debug!("storing aggregate job response"); + let (result, store_aggregate_job_response_duration) = time_it( + app.durable() + .with_retry() + .request( + agg_job_response_store::Command::Put, + (version, task_id, &agg_job_id), + ) + .encode(&agg_job_response) + .send::<()>(), + ) + .await; + + match result { + Ok(()) => {} + Err(e) => bail!(retry message, err = e, "failed to store aggregation response"), + } + + tracing::info!( + ?get_task_config_duration, + ?initialize_reports_duration, + ?check_for_replays_duration, + ?store_aggregate_share_duration, + ?store_aggregate_job_response_duration, + "successfully aggregated" + ); +} + +#[tracing::instrument(skip_all)] +pub async fn async_aggregate(app: App, message_batch: MessageBatch<()>) { + tracing::info!( + count = message_batch.raw_iter().count(), + "handling message batch" + ); + let app = &app; + message_batch + .raw_iter() + .map(|m| async move { + let message = match deserialize(&m) { + Ok(m) => m, + Err(e) => bail!(err = e, "failed to deserialize replay queue message"), + }; + async_aggregate_one(app, &m, message).await + }) + .collect::>() + .collect::<()>() + .await; +} + +async fn time_it(future: F) -> (F::Output, Duration) +where + F: Future, +{ + let now = worker::Date::now(); + let result = future.await; + (result, elapsed(&now)) +} diff --git a/crates/daphne-worker/src/aggregator/queues/mod.rs b/crates/daphne-worker/src/aggregator/queues/mod.rs new file mode 100644 index 000000000..904f7ddc2 --- /dev/null +++ b/crates/daphne-worker/src/aggregator/queues/mod.rs @@ -0,0 +1,39 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +mod async_aggregator; + +pub use async_aggregator::{async_aggregate, AsyncAggregationMessage}; +use daphne_service_utils::capnproto::{CapnprotoPayloadEncode, CapnprotoPayloadEncodeExt as _}; +use std::marker::PhantomData; +use worker::RawMessageBuilder; + +pub struct Queue { + queue: worker::Queue, + _message_type: PhantomData, +} + +impl Queue { + #[tracing::instrument(skip_all, fields(message = std::any::type_name::()))] + pub async fn send(&self, message: &T) -> worker::Result<()> { + tracing::info!("submiting queue message"); + let bytes = worker::js_sys::Uint8Array::from(message.encode_to_bytes().as_slice()); + self.queue + .send_raw( + RawMessageBuilder::new(bytes.into()) + .build_with_content_type(worker::QueueContentType::V8), + ) + .await?; + + Ok(()) + } +} + +impl From for Queue { + fn from(queue: worker::Queue) -> Self { + Self { + queue, + _message_type: PhantomData, + } + } +} diff --git a/crates/daphne-worker/src/aggregator/queues/queue_messages.capnp b/crates/daphne-worker/src/aggregator/queues/queue_messages.capnp new file mode 100644 index 000000000..dde73ff91 --- /dev/null +++ b/crates/daphne-worker/src/aggregator/queues/queue_messages.capnp @@ -0,0 +1,23 @@ +# Copyright (c) 2025 Cloudflare, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +@0x8240fbeac47031a3; + +using Base = import "/capnproto/base.capnp"; +using ComputeOffload = import "/compute_offload/compute_offload.capnp"; + +struct Option(T) { + union { + none @0 :Void; + some @1 :T; + } +} + +struct AsyncAggregationMessage @0xbe3d785aff491226 { + version @0 :Base.DapVersion; + reports @1 :List(Base.ReportId); + aggregationJobId @2 :Base.AggregationJobId; + partialBatchSelector @3 :Base.PartialBatchSelector; + initializeReports @4 :ComputeOffload.InitializeReports; + taskprovAdvertisement @5 :Option(Text); +} diff --git a/crates/daphne-worker/src/aggregator/roles/aggregator.rs b/crates/daphne-worker/src/aggregator/roles/aggregator.rs index 1d9124e4c..864040649 100644 --- a/crates/daphne-worker/src/aggregator/roles/aggregator.rs +++ b/crates/daphne-worker/src/aggregator/roles/aggregator.rs @@ -21,9 +21,10 @@ use daphne::{ DapVersion, }; use daphne_service_utils::durable_requests::bindings::{ - self, AggregateStoreMergeOptions, AggregateStoreMergeReq, AggregateStoreMergeResp, + self, aggregate_store_v2, AggregateStoreMergeOptions, AggregateStoreMergeReq, + AggregateStoreMergeResp, }; -use futures::{future::try_join_all, StreamExt as _, TryFutureExt as _, TryStreamExt as _}; +use futures::{future::try_join_all, StreamExt, TryFutureExt as _, TryStreamExt}; use mappable_rc::Marc; use std::{num::NonZeroUsize, ops::Range}; use worker::send::SendFuture; @@ -75,72 +76,31 @@ impl DapAggregator for App { .await } + // this implementation is hardcoded for the helper #[tracing::instrument(skip(self))] async fn get_agg_share( &self, + version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result<DapAggregateShare, DapError> { - let task_config = self - .get_task_config_for(task_id) - .await? - .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { - task_id: *task_id, - }))?; - - let durable = self.durable(); - let mut requests = Vec::new(); - for bucket in task_config.as_ref().batch_span_for_sel(batch_sel)? { - requests.push( - durable - .request( - bindings::AggregateStore::Get, - (task_config.as_ref().version, task_id, &bucket), - ) - .send(), - ); - } - let responses: Vec<DapAggregateShare> = try_join_all(requests) - .await - .map_err(|e| fatal_error!(err = ?e, "failed to get agg shares from durable objects"))?; - let mut agg_share = DapAggregateShare::default(); - for agg_share_delta in responses { - agg_share.merge(agg_share_delta)?; + match version { + DapVersion::Latest => self.get_agg_share_draft_latest(task_id, batch_sel).await, + DapVersion::Draft09 => self.get_agg_share_draft_09(task_id, batch_sel).await, } - - Ok(agg_share) } #[tracing::instrument(skip(self))] async fn mark_collected( &self, + version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result<(), DapError> { - let task_config = self - .get_task_config_for(task_id) - .await? - .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { - task_id: *task_id, - }))?; - - let durable = self.durable(); - let mut requests = Vec::new(); - for bucket in task_config.as_ref().batch_span_for_sel(batch_sel)?
{ - requests.push( - durable - .request( - bindings::AggregateStore::MarkCollected, - (task_config.as_ref().version, task_id, &bucket), - ) - .send::<()>(), - ); + match version { + DapVersion::Draft09 => self.mark_collected_draft09(task_id, batch_sel).await, + DapVersion::Latest => self.mark_collected_draft_latest(task_id, batch_sel).await, } - - try_join_all(requests) - .await - .map_err(|e| fatal_error!(err = ?e, "failed to mark agg shares as collected"))?; - Ok(()) } async fn get_global_config(&self) -> Result<DapGlobalConfig, DapError> { @@ -251,89 +211,29 @@ impl DapAggregator for App { async fn is_batch_overlapping( &self, + version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result<bool, DapError> { - let task_config = self - .get_task_config_for(task_id) - .await? - .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { - task_id: *task_id, - }))?; - - // Check whether the request overlaps with previous requests. This is done by - // checking the AggregateStore and seeing whether it requests for aggregate - // shares that have already been marked collected. - let durable = self.durable(); - Ok( - futures::stream::iter(task_config.batch_span_for_sel(batch_sel)?) - .map(|bucket| { - durable - .request( - bindings::AggregateStore::CheckCollected, - (task_config.as_ref().version, task_id, &bucket), - ) - .send() - }) - .buffer_unordered(usize::MAX) - .try_any(std::future::ready) - .await - .map_err( - |e| fatal_error!(err = ?e, "failed to check if agg shares are collected"), - )?, - ) + match version { + DapVersion::Draft09 => self.is_batch_overlapping_draft09(task_id, batch_sel).await, + DapVersion::Latest => { + self.is_batch_overlapping_draft_latest(task_id, batch_sel) + .await + } + } } - async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result<bool, DapError> { - let task_config = self - .get_task_config_for(task_id) - .await? - .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { - task_id: *task_id, - }))?; - let version = task_config.as_ref().version; - - let agg_span = task_config.batch_span_for_sel(&BatchSelector::LeaderSelectedByBatchId { - batch_id: *batch_id, - })?; - - futures::stream::iter(agg_span) - .map(|bucket| async move { - let durable = self.durable(); - let params = (version, task_id, &bucket); - - let get_report_count = || { - durable - .request(bindings::AggregateStore::ReportCount, params) - .send::<u64>() - }; - - // TODO: remove this after the worker has this feature deployed. - let backwards_compat_get_report_count = || { - durable - .request(bindings::AggregateStore::Get, params) - .send::<DapAggregateShare>() - .map_ok(|r| r.report_count) - }; - - let count = get_report_count() - .or_else(|_| backwards_compat_get_report_count()) - .await - .map_err(|e| { - fatal_error!( - err = ?e, - params = ?params, - "failed fetching report count of an agg share" - ) - })?; - Ok(count > 0) - }) - .buffer_unordered(usize::MAX) - .collect::<Vec<_>>() - .await - .into_iter() - .reduce(|a, b| Ok(a?
|| b?)) - .unwrap_or(Ok(false)) + async fn batch_exists( + &self, + version: DapVersion, + task_id: &TaskId, + batch_id: &BatchId, + ) -> Result<bool, DapError> { + match version { + DapVersion::Draft09 => self.batch_exists_draft09(task_id, batch_id).await, + DapVersion::Latest => self.batch_exists_draft_latest(task_id, batch_id).await, + } } fn metrics(&self) -> &dyn DaphneMetrics { @@ -434,3 +334,291 @@ impl hpke::HpkeProvider for App { )) } } + +impl App { + async fn get_agg_share_draft_09( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result<DapAggregateShare, DapError> { + tracing::debug!("gathering agg share"); + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + + let durable = self.durable(); + let mut requests = Vec::new(); + for bucket in task_config.as_ref().batch_span_for_sel(batch_sel)? { + requests.push( + durable + .request( + bindings::AggregateStore::Get, + (task_config.as_ref().version, task_id, &bucket), + ) + .send(), + ); + } + let responses: Vec<DapAggregateShare> = try_join_all(requests) + .await + .map_err(|e| fatal_error!(err = ?e, "failed to get agg shares from durable objects"))?; + let mut agg_share = DapAggregateShare::default(); + for agg_share_delta in responses { + agg_share.merge(agg_share_delta)?; + } + + Ok(agg_share) + } + + async fn mark_collected_draft09( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result<(), DapError> { + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + + let durable = self.durable(); + let mut requests = Vec::new(); + for bucket in task_config.as_ref().batch_span_for_sel(batch_sel)? { + requests.push( + durable + .request( + bindings::AggregateStore::MarkCollected, + (task_config.as_ref().version, task_id, &bucket), + ) + .send::<()>(), + ); + } + + try_join_all(requests) + .await + .map_err(|e| fatal_error!(err = ?e, "failed to mark agg shares as collected"))?; + Ok(()) + } + + async fn is_batch_overlapping_draft09( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result<bool, DapError> { + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + + // Check whether the request overlaps with previous requests. This is done by + // checking the AggregateStore and seeing whether it requests for aggregate + // shares that have already been marked collected. + let durable = self.durable(); + futures::stream::iter(task_config.batch_span_for_sel(batch_sel)?) + .map(|bucket| { + durable + .request( + bindings::AggregateStore::CheckCollected, + (task_config.as_ref().version, task_id, &bucket), + ) + .send() + }) + .buffer_unordered(usize::MAX) + .try_any(std::future::ready) + .await + .map_err(|e| fatal_error!(err = ?e, "failed to check if agg shares are collected")) + } + + async fn batch_exists_draft09( + &self, + task_id: &TaskId, + batch_id: &BatchId, + ) -> Result<bool, DapError> { + let task_config = self + .get_task_config_for(task_id) + .await?
+ .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + let version = task_config.as_ref().version; + + let agg_span = task_config.batch_span_for_sel(&BatchSelector::LeaderSelectedByBatchId { + batch_id: *batch_id, + })?; + + futures::stream::iter(agg_span) + .map(|bucket| async move { + let durable = self.durable(); + let params = (version, task_id, &bucket); + + let get_report_count = || { + durable + .request(bindings::AggregateStore::ReportCount, params) + .send::<u64>() + }; + + // TODO: remove this after the worker has this feature deployed. + let backwards_compat_get_report_count = || { + durable + .request(bindings::AggregateStore::Get, params) + .send::<DapAggregateShare>() + .map_ok(|r| r.report_count) + }; + + let count = get_report_count() + .or_else(|_| backwards_compat_get_report_count()) + .await + .map_err(|e| { + fatal_error!( + err = ?e, + params = ?params, + "failed fetching report count of an agg share" + ) + })?; + Ok(count > 0) + }) + .buffer_unordered(usize::MAX) + .collect::<Vec<_>>() + .await + .into_iter() + .reduce(|a, b| Ok(a? || b?)) + .unwrap_or(Ok(false)) + } + + async fn get_agg_share_draft_latest( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result<DapAggregateShare, DapError> { + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + + let durable = self.durable(); + let agg_share = futures::stream::iter(task_config.as_ref().batch_span_for_sel(batch_sel)?) + .map(|bucket| { + durable + .request( + aggregate_store_v2::Command::Get, + (task_config.as_ref().version, task_id, &bucket), + ) + .send() + .map_err( + |e| fatal_error!(err = ?e, "failed to get agg shares from durable objects"), + ) + }) + .buffer_unordered(6) + .try_fold(DapAggregateShare::default(), |mut acc, share| async move { + acc.merge(share).map(|()| acc) + }) + .await?; + + Ok(agg_share) + } + + async fn mark_collected_draft_latest( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result<(), DapError> { + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + + let durable = self.durable(); + futures::stream::iter(task_config.as_ref().batch_span_for_sel(batch_sel)?) + .map(|bucket| { + durable + .request( + aggregate_store_v2::Command::MarkCollected, + (task_config.as_ref().version, task_id, &bucket), + ) + .send::<()>() + .map_err(|e| fatal_error!(err = ?e, "failed to mark agg shares as collected")) + }) + .buffer_unordered(6) + .try_collect::<()>() + .await?; + + Ok(()) + } + + async fn is_batch_overlapping_draft_latest( + &self, + task_id: &TaskId, + batch_sel: &BatchSelector, + ) -> Result<bool, DapError> { + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + + // Check whether the request overlaps with previous requests. This is done by + // checking the AggregateStore and seeing whether it requests for aggregate + // shares that have already been marked collected. + let durable = self.durable(); + futures::stream::iter(task_config.batch_span_for_sel(batch_sel)?)
+ .map(|bucket| { + durable + .request( + aggregate_store_v2::Command::CheckCollected, + (task_config.as_ref().version, task_id, &bucket), + ) + .send() + }) + .buffer_unordered(usize::MAX) + .try_any(std::future::ready) + .await + .map_err(|e| fatal_error!(err = ?e, "failed to check if agg shares are collected")) + } + + async fn batch_exists_draft_latest( + &self, + task_id: &TaskId, + batch_id: &BatchId, + ) -> Result<bool, DapError> { + let task_config = self + .get_task_config_for(task_id) + .await? + .ok_or(DapError::Abort(DapAbort::UnrecognizedTask { + task_id: *task_id, + }))?; + let version = task_config.as_ref().version; + + let agg_span = task_config.batch_span_for_sel(&BatchSelector::LeaderSelectedByBatchId { + batch_id: *batch_id, + })?; + + futures::stream::iter(agg_span) + .map(|bucket| async move { + let params = (version, task_id, &bucket); + self.durable() + .request(aggregate_store_v2::Command::AggregateShareCount, params) + .send::<usize>() + .await + .map_err(|e| { + fatal_error!( + err = ?e, + params = ?params, + "failed fetching report count of an agg share" + ) + }) + }) + .buffer_unordered(usize::MAX) + .try_all(|count| std::future::ready(count > 0)) + .await + } +} diff --git a/crates/daphne-worker/src/aggregator/roles/helper.rs b/crates/daphne-worker/src/aggregator/roles/helper.rs index 26e284b49..0b27aa383 100644 --- a/crates/daphne-worker/src/aggregator/roles/helper.rs +++ b/crates/daphne-worker/src/aggregator/roles/helper.rs @@ -5,11 +5,14 @@ use crate::aggregator::App; use daphne::{ error::DapAbort, fatal_error, - messages::{AggregationJobId, TaskId}, + messages::{AggregationJobId, AggregationJobResp, TaskId}, + protocol::ReadyAggregationJobResp, roles::{helper::AggregationJobRequestHash, DapHelper}, DapError, DapVersion, }; -use daphne_service_utils::durable_requests::bindings::aggregation_job_store; +use daphne_service_utils::durable_requests::bindings::{ + agg_job_response_store, aggregation_job_store, +}; use std::borrow::Cow; #[axum::async_trait] @@ -41,4 +44,46 @@ impl DapHelper for App { ), } } + + async fn poll_aggregated( + &self, + version: DapVersion, + task_id: &TaskId, + agg_job_id: &AggregationJobId, + ) -> Result<AggregationJobResp, DapError> { + let valid_agg_job_id = self + .durable() + .with_retry() + .request( + aggregation_job_store::Command::ContainsJob, + (version, task_id), + ) + .encode(agg_job_id) + .send::<bool>() + .await + .map_err(|e| fatal_error!(err = ?e, "failed to query the validity of the aggregation job id"))?; + + if !valid_agg_job_id { + return Err(DapError::Abort(DapAbort::UnrecognizedAggregationJob { + task_id: *task_id, + agg_job_id: *agg_job_id, + })); + } + + let response = self + .durable() + .with_retry() + .request( + agg_job_response_store::Command::Get, + (version, task_id, agg_job_id), + ) + .send::<Option<ReadyAggregationJobResp>>() + .await + .map_err(|e| fatal_error!(err = ?e, "failed to poll for aggregation job response"))?; + + match response { + Some(ready) => Ok(ready.into()), + None => Ok(AggregationJobResp::Processing), + } + } } diff --git a/crates/daphne-worker/src/aggregator/roles/leader.rs b/crates/daphne-worker/src/aggregator/roles/leader.rs index 0a6e48fa4..8f72f61ce 100644 --- a/crates/daphne-worker/src/aggregator/roles/leader.rs +++ b/crates/daphne-worker/src/aggregator/roles/leader.rs @@ -125,6 +125,10 @@ impl DapLeader for App { { self.send_http(meta, Method::PUT, url, payload).await } + + async fn send_http_get(&self, meta: DapRequestMeta, url: Url) -> Result<DapResponse, DapError> { + self.send_http(meta, Method::GET, url, ()).await + } } impl App { @@ -207,7 +211,7 @@ impl App { let req_builder = self
.http - .request(method, url.clone()) + .request(method.clone(), url.clone()) .body( payload .get_encoded_with_param(&meta.version) .map_err(DapError::encoding)?, ) @@ -220,7 +224,10 @@ impl App { .send() .await .map_err(|e| fatal_error!(err = ?e, "failed to send request to the helper"))?; - info!("request to {} completed in {:?}", url, elapsed(&start)); + info!( + "{method} request to {url} completed in {:?}", + elapsed(&start) + ); let status = reqwest_resp.status(); if status.is_success() { diff --git a/crates/daphne-worker/src/aggregator/router/extractor.rs b/crates/daphne-worker/src/aggregator/router/extractor.rs index 9ea29fe5f..9e4055f43 100644 --- a/crates/daphne-worker/src/aggregator/router/extractor.rs +++ b/crates/daphne-worker/src/aggregator/router/extractor.rs @@ -11,7 +11,7 @@ use daphne::{ error::DapAbort, fatal_error, messages::{ - request::{CollectionPollReq, RequestBody}, + request::{CollectionPollReq, PollAggregationJob, RequestBody}, taskprov::TaskprovAdvertisement, AggregateShareReq, AggregationJobInitReq, CollectionReq, Report, TaskId, }, @@ -68,6 +68,12 @@ impl DecodeFromDapHttpBody for HashedAggregationJobReq { } } +impl DecodeFromDapHttpBody for PollAggregationJob { + fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result<Self, DapAbort> { + Ok(Self) + } +} + /// Using `()` ignores the body of a request. impl DecodeFromDapHttpBody for CollectionPollReq { fn decode_from_http_body(_bytes: Bytes, _meta: &DapRequestMeta) -> Result<Self, DapAbort> { diff --git a/crates/daphne-worker/src/aggregator/router/helper.rs b/crates/daphne-worker/src/aggregator/router/helper.rs index 848739a8e..3483c4fc5 100644 --- a/crates/daphne-worker/src/aggregator/router/helper.rs +++ b/crates/daphne-worker/src/aggregator/router/helper.rs @@ -5,7 +5,7 @@ use super::{ super::roles::fetch_replay_protection_override, extractor::dap_sender::FROM_LEADER, App, AxumDapResponse, DapRequestExtractor, DaphneService, }; -use crate::elapsed; +use crate::{aggregator::queues, elapsed}; use axum::{ extract::State, routing::{post, put}, @@ -14,12 +14,12 @@ use daphne::{ error::DapAbort, fatal_error, hpke::HpkeProvider, - messages::AggregateShareReq, + messages::{request::PollAggregationJob, AggregateShareReq, AggregationJobResp}, roles::{ helper::{self, HashedAggregationJobReq}, DapAggregator, DapHelper, }, - DapAggregationParam, DapError, DapResponse, + DapAggregationParam, DapError, DapResponse, DapVersion, }; use daphne_service_utils::compute_offload; use http::StatusCode; @@ -30,7 +30,7 @@ pub(super) fn add_helper_routes(router: super::Router) -> super::Router) -> super::Router>, + req: DapRequestExtractor<HashedAggregationJobReq>, +) -> AxumDapResponse { + match req.0.version { + DapVersion::Draft09 => agg_job_draft9(state, req).await, + DapVersion::Latest => agg_job_draft_latest(state, req).await, + } +} + +async fn agg_job_draft_latest( + State(app): State<Arc<A>>, + DapRequestExtractor(req): DapRequestExtractor<HashedAggregationJobReq>, +) -> AxumDapResponse { + let now = worker::Date::now(); + let version = req.version; + + let queue_result = async { + let (transition, req) = helper::handle_agg_job::start(req) + .check_aggregation_job_legality(&*app) + .await? + .resolve_task_config(&*app) + .await?
+ .into_parts(fetch_replay_protection_override(app.kv()).await)?; + + let hpke_receiver_configs = app.get_hpke_receiver_configs(req.version).await?; + + app.async_aggregation_queue() + .send(&queues::AsyncAggregationMessage { + version, + part_batch_sel: transition.part_batch_sel, + agg_job_id: req.resource_id, + taskprov_advertisement: req + .taskprov_advertisement + .as_ref() + .map(|t| t.serialize_to_header_value(version)) + .transpose()?, + initialize_reports: compute_offload::InitializeReports { + hpke_keys: Cow::Borrowed(hpke_receiver_configs.as_ref()), + valid_report_range: app.valid_report_time_range(), + task_id: req.task_id, + task_config: (&transition.task_config).into(), + agg_param: Cow::Borrowed(&req.payload.agg_param), + prep_inits: req.payload.prep_inits, + }, + }) + .await + .map_err(|e| fatal_error!(err = ?e, "failed to queue response")) + } + .await; + + let elapsed = elapsed(&now); + + app.server_metrics().aggregate_job_latency(elapsed); + + AxumDapResponse::from_result_with_success_code( + queue_result.and_then(|()| { + Ok(DapResponse { + version, + media_type: daphne::constants::DapMediaType::AggregationJobResp, + payload: AggregationJobResp::Processing + .get_encoded_with_param(&version) + .map_err(DapError::encoding)?, + }) + }), + app.server_metrics(), + StatusCode::CREATED, + ) +} + +async fn agg_job_draft9( State(app): State<Arc<A>>, DapRequestExtractor(req): DapRequestExtractor<HashedAggregationJobReq>, ) -> AxumDapResponse { @@ -109,6 +179,27 @@ async fn agg_job( ) } +#[tracing::instrument( + skip_all, + fields( + media_type = ?req.media_type, + task_id = ?req.task_id, + version = ?req.version, + ) +)] +async fn poll_agg_job<A>( + State(app): State<Arc<A>>, + DapRequestExtractor(req): DapRequestExtractor<PollAggregationJob>, +) -> AxumDapResponse +where + A: DapHelper + DaphneService + Send + Sync, +{ + AxumDapResponse::from_result( + helper::handle_agg_job_poll_req(&*app, req).await, + app.server_metrics(), + ) +} + #[tracing::instrument( skip_all, fields( @@ -124,6 +215,7 @@ async fn agg_share( where A: DapHelper + DaphneService + Send + Sync, { AxumDapResponse::from_result( helper::handle_agg_share_req(&*app, req).await, app.server_metrics(), diff --git a/crates/daphne-worker/src/durable/agg_job_response_store.rs b/crates/daphne-worker/src/durable/agg_job_response_store.rs new file mode 100644 index 000000000..77110bd3a --- /dev/null +++ b/crates/daphne-worker/src/durable/agg_job_response_store.rs @@ -0,0 +1,118 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Durable Object for storing the result of an aggregation job. + +use super::{req_parse, GcDurableObject}; +use crate::int_err; +use daphne::protocol::ReadyAggregationJobResp; +use daphne_service_utils::durable_requests::bindings::{ + self, agg_job_response_store, DurableMethod as _, +}; +use std::{sync::OnceLock, time::Duration}; +use worker::{js_sys, Env, Request, Response, Result, ScheduledTime, State}; + +const AGGREGATE_RESPONSE_CHUNK_KEY_PREFIX: &str = "dap_agg_response_chunk"; + +super::mk_durable_object! { + /// Where the aggregation job response is stored.
For the binding name see its + /// [`BINDING`](bindings::AggregateStore::BINDING) + struct AggJobResponseStore { + state: State, + env: Env, + agg_job_resp: Option<ReadyAggregationJobResp>, + } +} + +impl AggJobResponseStore { + async fn get_agg_job_response(&mut self) -> Result<Option<&ReadyAggregationJobResp>> { + let agg_job_resp = if let Some(agg_job_resp) = self.agg_job_resp.take() { + agg_job_resp + } else { + let Some(agg_job_resp) = self + .load_chuncked_value(AGGREGATE_RESPONSE_CHUNK_KEY_PREFIX) + .await? + else { + return Ok(None); + }; + agg_job_resp + }; + + self.agg_job_resp = Some(agg_job_resp); + + Ok(self.agg_job_resp.as_ref()) + } + + fn put_agg_job_response(&mut self, resp: ReadyAggregationJobResp) -> Result<js_sys::Object> { + let obj = self.serialize_chunked_value(AGGREGATE_RESPONSE_CHUNK_KEY_PREFIX, &resp, None)?; + self.agg_job_resp = Some(resp); + Ok(obj) + } +} + +impl GcDurableObject for AggJobResponseStore { + type DurableMethod = bindings::AggregateStore; + + fn with_state_and_env(state: State, env: Env) -> Self { + Self { + state, + env, + agg_job_resp: None, + } + } + + async fn handle(&mut self, mut req: Request) -> Result<Response> { + match agg_job_response_store::Command::try_from_uri(&req.path()) { + // Store an aggregation job response. + // + // Idempotent + // Input: `ReadyAggregationJobResp` + // Output: `()` + Some(agg_job_response_store::Command::Put) => { + let response = req_parse::<ReadyAggregationJobResp>(&mut req).await?; + + self.state + .storage() + .put_multiple_raw(self.put_agg_job_response(response)?) + .await?; + + Response::from_json(&()) + } + + // Get the AggregationJobResp + // + // Idempotent + // Output: `Option<ReadyAggregationJobResp>` + Some(agg_job_response_store::Command::Get) => { + let response = self.get_agg_job_response().await?; + Response::from_json(&response) + } + + None => Err(int_err(format!( + "AggregatesStore: unexpected request: method={:?}; path={:?}", + req.method(), + req.path() + ))), + } + } + + fn should_cleanup_at(&self) -> Option<ScheduledTime> { + const VAR_NAME: &str = "DAP_DURABLE_AGGREGATE_STORE_GC_AFTER_SECS"; + static SELF_DELETE_AFTER: OnceLock<Duration> = OnceLock::new(); + + let duration = SELF_DELETE_AFTER.get_or_init(|| { + Duration::from_secs( + self.env + .var(VAR_NAME) + .map(|v| { + v.to_string().parse().unwrap_or_else(|e| { + panic!("{VAR_NAME} could not be parsed as a number of seconds: {e}") + }) + }) + .unwrap_or(60 * 60 * 24 * 7), // one week + ) + }); + + Some(ScheduledTime::from(*duration)) + } +} diff --git a/crates/daphne-worker/src/durable/aggregate_store_v2.rs b/crates/daphne-worker/src/durable/aggregate_store_v2.rs new file mode 100644 index 000000000..60014a54f --- /dev/null +++ b/crates/daphne-worker/src/durable/aggregate_store_v2.rs @@ -0,0 +1,178 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +//! Durable Object for storing the result of an aggregation job. + +use super::{req_parse, GcDurableObject}; +use crate::int_err; +use daphne::DapAggregateShare; +use daphne_service_utils::durable_requests::bindings::{ + self, aggregate_store_v2, DurableMethod as _, +}; +use futures::{StreamExt, TryStreamExt}; +use std::{sync::OnceLock, time::Duration}; +use worker::{js_sys, wasm_bindgen::JsValue, Env, Request, Response, Result, ScheduledTime, State}; + +const AGGREGATION_JOB_IDS_KEY: &str = "agg-job-ids"; +const COLLECTED_FLAG_KEY: &str = "collected"; + +super::mk_durable_object! { + /// Where the aggregate share is stored.
For the binding name see its + /// [`BINDING`](bindings::AggregateStore::BINDING) + struct AggregateStoreV2 { + state: State, + env: Env, + collected: Option<bool>, + } +} + +impl AggregateStoreV2 { + async fn get_agg_share(&self, agg_job_id: &str) -> Result<Option<DapAggregateShare>> { + self.load_chuncked_value(agg_job_id).await + } + + fn put_agg_share( + &mut self, + agg_job_id: &str, + share: DapAggregateShare, + obj: js_sys::Object, + ) -> Result<js_sys::Object> { + self.serialize_chunked_value(agg_job_id, &share, obj) + } + + async fn is_collected(&mut self) -> Result<bool> { + Ok(if let Some(collected) = self.collected { + collected + } else { + let collected = self.get_or_default(COLLECTED_FLAG_KEY).await?; + self.collected = Some(collected); + collected + }) + } +} + +impl GcDurableObject for AggregateStoreV2 { + type DurableMethod = bindings::AggregateStore; + + fn with_state_and_env(state: State, env: Env) -> Self { + Self { + state, + env, + collected: None, + } + } + + async fn handle(&mut self, mut req: Request) -> Result<Response> { + match aggregate_store_v2::Command::try_from_uri(&req.path()) { + // Store an aggregate share delta for an aggregation job. + // + // Idempotent + // Input: `aggregate_store_v2::PutRequest` + // Output: `()` + Some(aggregate_store_v2::Command::Put) => { + let aggregate_store_v2::PutRequest { + agg_job_id, + agg_share_delta, + } = req_parse(&mut req).await?; + + let mut agg_job_ids = self + .get_or_default::<Vec<String>>(AGGREGATION_JOB_IDS_KEY) + .await?; + + let chunks_map = js_sys::Object::default(); + + let agg_job_id = agg_job_id.to_string(); + + let chunks_map = self.put_agg_share(&agg_job_id, agg_share_delta, chunks_map)?; + + agg_job_ids.push(agg_job_id); + js_sys::Reflect::set( + &chunks_map, + &JsValue::from_str(AGGREGATION_JOB_IDS_KEY), + &serde_wasm_bindgen::to_value(&agg_job_ids)?, + )?; + + self.state.storage().put_multiple_raw(chunks_map).await?; + + Response::from_json(&()) + } + + // Get the current aggregate share. + // + // Idempotent + // Output: `DapAggregateShare` + Some(aggregate_store_v2::Command::Get) => { + let ids = self + .get_or_default::<Vec<String>>(AGGREGATION_JOB_IDS_KEY) + .await?; + let this = &self; + let share = futures::stream::iter(ids) + .map(|id| async move { this.get_agg_share(&id).await }) + .buffer_unordered(8) + .filter_map(|share| async move { share.transpose() }) + .try_fold(DapAggregateShare::default(), |mut acc, share| async move { + acc.merge(share) + .map(|()| acc) + .map_err(|e| worker::Error::RustError(e.to_string())) + }) + .await?; + Response::from_json(&share) + } + + // Mark this bucket as collected. + // + // Idempotent + // Output: `()` + Some(aggregate_store_v2::Command::MarkCollected) => { + self.state.storage().put(COLLECTED_FLAG_KEY, true).await?; + self.collected = Some(true); + Response::from_json(&()) + } + + // Get the value of the flag indicating whether this bucket has been collected. + // + // Idempotent + // Output: `bool` + Some(aggregate_store_v2::Command::CheckCollected) => { + Response::from_json(&self.is_collected().await?) + } + + // Get the number of aggregate shares stored in this bucket. + // + // Idempotent + // Output: `usize` + Some(aggregate_store_v2::Command::AggregateShareCount) => Response::from_json( + &self + .get_or_default::<Vec<String>>(AGGREGATION_JOB_IDS_KEY) + .await?
+ .len(), + ), + + None => Err(int_err(format!( + "AggregatesStore: unexpected request: method={:?}; path={:?}", + req.method(), + req.path() + ))), + } + } + + fn should_cleanup_at(&self) -> Option<ScheduledTime> { + const VAR_NAME: &str = "DAP_DURABLE_AGGREGATE_STORE_GC_AFTER_SECS"; + static SELF_DELETE_AFTER: OnceLock<Duration> = OnceLock::new(); + + let duration = SELF_DELETE_AFTER.get_or_init(|| { + Duration::from_secs( + self.env + .var(VAR_NAME) + .map(|v| { + v.to_string().parse().unwrap_or_else(|e| { + panic!("{VAR_NAME} could not be parsed as a number of seconds: {e}") + }) + }) + .unwrap_or(60 * 60 * 24 * 7), // one week + ) + }); + + Some(ScheduledTime::from(*duration)) + } +} diff --git a/crates/daphne-worker/src/durable/aggregation_job_store.rs b/crates/daphne-worker/src/durable/aggregation_job_store.rs index d2a0d3e54..5ef10709f 100644 --- a/crates/daphne-worker/src/durable/aggregation_job_store.rs +++ b/crates/daphne-worker/src/durable/aggregation_job_store.rs @@ -56,8 +56,10 @@ impl GcDurableObject for AggregationJobStore { Response::from_json(&response) } - Some(aggregation_job_store::Command::ListJobIds) => { - Response::from_json(&self.load_seen_agg_job_ids().await?) + Some(aggregation_job_store::Command::ContainsJob) => { + let agg_job_id = req_parse::<AggregationJobId>(&mut req).await?; + let has = self.has(&agg_job_id.to_string()).await?; + Response::from_json(&has) } None => Err(int_err(format!( "AggregationJobStore: unexpected request: method={:?}; path={:?}", diff --git a/crates/daphne-worker/src/durable/mod.rs b/crates/daphne-worker/src/durable/mod.rs index b719e8344..30332b7bf 100644 --- a/crates/daphne-worker/src/durable/mod.rs +++ b/crates/daphne-worker/src/durable/mod.rs @@ -19,12 +19,14 @@ //! To know what values to provide to the `name` and `class_name` fields see each type exported by //! this module as well as the [`instantiate_durable_object`] macro, respectively. +pub(crate) mod agg_job_response_store; pub(crate) mod aggregate_store; +pub(crate) mod aggregate_store_v2; pub(crate) mod aggregation_job_store; +pub(crate) mod replay_checker; #[cfg(feature = "test-utils")] pub(crate) mod test_state_cleaner; -use crate::tracing_utils::shorten_paths; use daphne_service_utils::{ capnproto::{CapnprotoPayloadDecode, CapnprotoPayloadDecodeExt}, durable_requests::bindings::DurableMethod, }; use serde::{Deserialize, Serialize}; use tracing::info_span; use worker::{Env, Error, Request, Response, Result, ScheduledTime, State}; +pub use agg_job_response_store::AggJobResponseStore; pub use aggregate_store::AggregateStore; +pub use aggregate_store_v2::AggregateStoreV2; pub use aggregation_job_store::AggregationJobStore; +pub use replay_checker::ReplayChecker; const ERR_NO_VALUE: &str = "No such value in storage."; @@ -164,12 +169,93 @@ macro_rules! mk_durable_object { } #[allow(dead_code)] + /// Set a key/value pair unless the key already exists. If the key exists, then return the current + /// value. Otherwise return nothing.
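+ /// For example: the first call with a given key stores `val` and returns `None`; a later call + /// with the same key leaves storage untouched and returns `Some` holding the stored value.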
async fn put_if_not_exists<T>(&self, key: &str, val: &T) -> ::worker::Result<Option<T>> where T: ::serde::de::DeserializeOwned + ::serde::Serialize, { $crate::durable::state_set_if_not_exists(&self.state, key, val).await } + + #[allow(dead_code)] + async fn has(&self, key: &str) -> ::worker::Result<bool> { + $crate::durable::state_contains_key(&self.state, key).await + } + + #[allow(dead_code)] + async fn load_chuncked_value<T>(&self, prefix: &str) -> ::worker::Result<Option<T>> + where + T: daphne_service_utils::capnproto::CapnprotoPayloadDecode, + { + let Some(count) = self.get::<usize>(&format!("{prefix}_count")).await? else { + return Ok(None); + }; + + let keys = &$crate::durable::calculate_chunk_keys(count, prefix); + let map = self.state.storage().get_multiple(keys.clone()).await?; + let bytes = keys + .iter() + .map(|k| wasm_bindgen::JsValue::from_str(k.as_ref())) + .filter(|k| map.has(k)) + .map(|k| map.get(&k)) + .map(|js_v| { + serde_wasm_bindgen::from_value::<Vec<u8>>(js_v).expect("expect an array of bytes") + }) + .reduce(|mut buf, v| { + buf.extend_from_slice(&v); + buf + }); + + let Some(bytes) = bytes else { + return Ok(None); + }; + + <T as daphne_service_utils::capnproto::CapnprotoPayloadDecode>::decode_from_bytes(&bytes) + .map(Some) + .map_err(|e| worker::Error::RustError(e.to_string())) + } + + #[allow(dead_code)] + fn serialize_chunked_value<T>( + &self, + prefix: &str, + value: &T, + object_to_fill: impl Into<Option<worker::js_sys::Object>> + ) -> ::worker::Result<worker::js_sys::Object> + where + T: daphne_service_utils::capnproto::CapnprotoPayloadEncode, + { + let object_to_fill = object_to_fill.into().unwrap_or_default(); + use daphne_service_utils::capnproto::CapnprotoPayloadEncodeExt; + let bytes = value.encode_to_bytes(); + let chunk_keys = $crate::durable::chunk_keys_for(&bytes, prefix); + let mut base_idx = 0; + for key in &chunk_keys { + let end = usize::min(base_idx + $crate::durable::MAX_CHUNK_SIZE, bytes.len()); + let chunk = &bytes[base_idx..end]; + + // unwrap cannot fail because chunk len is bounded by MAX_CHUNK_SIZE which is smaller than + // u32::MAX + let value = worker::js_sys::Uint8Array::new_with_length(u32::try_from(chunk.len()).unwrap()); + value.copy_from(chunk); + + worker::js_sys::Reflect::set( + &object_to_fill, + &wasm_bindgen::JsValue::from_str(key.as_ref()), + &value.into(), + )?; + + base_idx = end; + } + worker::js_sys::Reflect::set( + &object_to_fill, + &wasm_bindgen::JsValue::from_str(&format!("{prefix}_count")), + &chunk_keys.len().into(), + )?; + + Ok(object_to_fill) + } } }; } @@ -220,6 +306,20 @@ pub(crate) async fn state_set_if_not_exists<T: for<'a> Deserialize<'a> + Seriali Ok(None) } +pub(crate) async fn state_contains_key(state: &State, key: &str) -> Result<bool> { + struct DevNull; + impl<'de> Deserialize<'de> for DevNull { + fn deserialize<D>(_: D) -> std::result::Result<Self, D::Error> + where + D: serde::Deserializer<'de>, + { + Ok(Self) + } + } + + Ok(state_get::<DevNull>(state, key).await?.is_some()) +} + async fn req_parse<T>(req: &mut Request) -> Result<T> where T: CapnprotoPayloadDecode, { ... } fn create_span_from_request(req: &Request) -> tracing::Span { let path = req.path(); - let span = info_span!("DO span", p = %shorten_paths(path.split('/')).display()); + let span = info_span!("DO span", ?path); span.in_scope(|| tracing::info!(path, "DO handling new request")); span } +/// The maximum chunk size as documented in +/// [the worker docs](https://developers.cloudflare.com/durable-objects/platform/limits/) +const MAX_CHUNK_SIZE: usize = 128_000; + +fn chunk_keys_for(bytes: &[u8], prefix: &str) -> Vec<String> { + // stolen from + // https://doc.rust-lang.org/std/primitive.usize.html#method.div_ceil + // because it's
nightly only + fn div_ceil(lhs: usize, rhs: usize) -> usize { + let d = lhs / rhs; + let r = lhs % rhs; + if r > 0 && rhs > 0 { + d + 1 + } else { + d + } + } + + calculate_chunk_keys(div_ceil(bytes.len(), MAX_CHUNK_SIZE), prefix) +} + +fn calculate_chunk_keys(count: usize, prefix: &str) -> Vec<String> { + (0..count).map(|i| format!("{prefix}_{i:04}")).collect() +} + /// Instantiate a durable object. /// /// # Syntax diff --git a/crates/daphne-worker/src/durable/replay_checker.rs b/crates/daphne-worker/src/durable/replay_checker.rs new file mode 100644 index 000000000..1b4aa63c7 --- /dev/null +++ b/crates/daphne-worker/src/durable/replay_checker.rs @@ -0,0 +1,118 @@ +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause + +use super::{req_parse, GcDurableObject}; +use crate::int_err; +use daphne::messages::{AggregationJobId, ReportId}; +use daphne_service_utils::durable_requests::bindings::{self, replay_checker, DurableMethod}; +use std::{ + collections::{HashMap, HashSet}, + iter::zip, + sync::OnceLock, + time::Duration, +}; +use wasm_bindgen::JsValue; +use worker::{js_sys, Env, Request, Response, Result, ScheduledTime, State}; + +super::mk_durable_object! { + /// Where report ids are stored for replay protection. + struct ReplayChecker { + state: State, + env: Env, + seen: HashMap<ReportId, AggregationJobId>, + } +} + +impl GcDurableObject for ReplayChecker { + type DurableMethod = bindings::AggregateStore; + + fn with_state_and_env(state: State, env: Env) -> Self { + Self { + state, + env, + seen: Default::default(), + } + } + + async fn handle(&mut self, mut req: Request) -> Result<Response> { + match replay_checker::Command::try_from_uri(&req.path()) { + Some(replay_checker::Command::Check) => { + let replay_checker::Request { + report_ids, + aggregation_job_id, + } = req_parse(&mut req).await?; + + let mut duplicates = HashSet::new(); + + let report_ids_as_string = report_ids + .iter() + .filter(|r| match self.seen.get(r) { + Some(cached_agg_job_id) => { + if *cached_agg_job_id != aggregation_job_id { + duplicates.insert(**r); + } + false // skip checking + } + None => true, // check against disk + }) + .map(ToString::to_string) + .collect::<Vec<_>>(); + + let aggregation_job_id_as_str = aggregation_job_id.to_string(); + + let result = self + .state + .storage() + .get_multiple(report_ids_as_string.clone()) + .await?; + + let obj_to_update = js_sys::Object::new(); + for (id, as_str) in zip(report_ids.iter(), &report_ids_as_string) { + self.seen.insert(*id, aggregation_job_id); + + let v = result.get(&JsValue::from_str(as_str)); + if let Some(stored_agg_job_id) = v.as_string() { + if stored_agg_job_id != aggregation_job_id_as_str { + duplicates.insert(*id); + } + } else { + js_sys::Reflect::set( + &obj_to_update, + &JsValue::from_str(as_str), + &JsValue::from_str(aggregation_job_id_as_str.as_ref()), + )?; + } + } + + self.state.storage().put_multiple_raw(obj_to_update).await?; + + Response::from_json(&replay_checker::Response { duplicates }) + } + None => Err(int_err(format!( + "AggregatesStore: unexpected request: method={:?}; path={:?}", + req.method(), + req.path() + ))), + } + } + + fn should_cleanup_at(&self) -> Option<ScheduledTime> { + const VAR_NAME: &str = "DO_REPLAY_CHECKER_GC_SECS"; + static SELF_DELETE_AFTER: OnceLock<Duration> = OnceLock::new(); + + let duration = SELF_DELETE_AFTER.get_or_init(|| { + Duration::from_secs( + self.env + .var(VAR_NAME) + .map(|v| { + v.to_string().parse().unwrap_or_else(|e| { + panic!("{VAR_NAME} could not be parsed as a number of seconds: {e}") + }) + }) + .unwrap_or(60
* 60 * 24 * 7), // one week + ) + }); + + Some(ScheduledTime::from(*duration)) + } +} diff --git a/crates/daphne-worker/src/durable/test_state_cleaner.rs b/crates/daphne-worker/src/durable/test_state_cleaner.rs index e2b065116..4e8a87e22 100644 --- a/crates/daphne-worker/src/durable/test_state_cleaner.rs +++ b/crates/daphne-worker/src/durable/test_state_cleaner.rs @@ -36,6 +36,7 @@ impl DurableObject for TestStateCleaner { } impl TestStateCleaner { + #[tracing::instrument(skip_all)] async fn handle(&mut self, mut req: Request) -> Result<Response> { let durable = DurableConnector::new(&self.env); match bindings::TestStateCleaner::try_from_uri(&req.path()) { @@ -69,19 +70,31 @@ impl TestStateCleaner { let queued: Vec<DurableOrdered<DurableReference>> = DurableOrdered::get_all(&self.state, "object").await?; for durable_ref in queued.iter().map(|queued| queued.as_ref()) { - durable + let result = durable .post_by_id_hex::<_, ()>( + &durable_ref.binding, bindings::TestStateCleaner::DeleteAll.to_uri(), durable_ref.id_hex.clone(), &(), ) - .await?; - console_debug!( - "deleted instance. binding: {binding}. instance: {instance}", - binding = durable_ref.binding, - instance = durable_ref.id_hex, - ); + .await; + match result { + Ok(()) => { + console_debug!( + "deleted instance. binding: {binding}. instance: {instance}", + binding = durable_ref.binding, + instance = durable_ref.id_hex, + ); + } + Err(e) => { + console_error!( + "failed to delete instance. binding: {binding}. instance: {instance}. error: {e}", + binding = durable_ref.binding, + instance = durable_ref.id_hex, + e = e, + ) + } + } } self.state.storage().delete_all().await?; @@ -184,6 +197,7 @@ impl<T> AsRef<T> for DurableOrdered<T> { } } +#[tracing::instrument(skip(state))] async fn get_front<T: for<'a> Deserialize<'a> + Serialize>( state: &State, prefix: &str, diff --git a/crates/daphne-worker/src/lib.rs b/crates/daphne-worker/src/lib.rs index 6b8fb25ea..16d2cde9c 100644 --- a/crates/daphne-worker/src/lib.rs +++ b/crates/daphne-worker/src/lib.rs @@ -30,3 +30,17 @@ pub(crate) fn int_err(s: S) -> Error { pub(crate) fn elapsed(date: &worker::Date) -> Duration { Duration::from_millis(worker::Date::now().as_millis() - date.as_millis()) } + +pub(crate) use daphne_service_utils::base_capnp; +pub(crate) use daphne_service_utils::compute_offload_capnp; + +mod queue_messages_capnp { + #![allow(dead_code)] + #![allow(clippy::pedantic)] + #![allow(clippy::needless_lifetimes)] + #![allow(clippy::extra_unused_type_parameters)] + include!(concat!( + env!("OUT_DIR"), + "/src/aggregator/queues/queue_messages_capnp.rs" + )); +} diff --git a/crates/daphne-worker/src/storage/kv/mod.rs b/crates/daphne-worker/src/storage/kv/mod.rs index 6cd83de2b..fd58b08d1 100644 --- a/crates/daphne-worker/src/storage/kv/mod.rs +++ b/crates/daphne-worker/src/storage/kv/mod.rs @@ -2,11 +2,9 @@ // SPDX-License-Identifier: BSD-3-Clause mod cache; -// mod request_coalescer; use std::{any::Any, fmt::Display, future::Future, sync::RwLock}; -use daphne_service_utils::durable_requests::KV_PATH_PREFIX; use mappable_rc::Marc; use serde::{de::DeserializeOwned, Serialize}; use tracing::{info_span, Instrument}; @@ -167,9 +165,12 @@ pub(crate) enum GetOrInsertError<E> { Other(E), } -impl<E> From<Error> for GetOrInsertError<E> { - fn from(error: Error) -> Self { - Self::StorageProxy(error) +impl<E, E2> From<E2> for GetOrInsertError<E> +where + Error: From<E2>, +{ + fn from(error: E2) -> Self { + Self::StorageProxy(error.into()) } } @@ -440,6 +441,6 @@ impl<'h> Kv<'h> { } fn to_key(key: &P::Key) -> String { - format!("{KV_PATH_PREFIX}/{}/{key}", P::PREFIX) + format!("{}/{key}",
P::PREFIX) } } diff --git a/crates/daphne-worker/src/storage/mod.rs b/crates/daphne-worker/src/storage/mod.rs index fce3c87ec..722a0f24f 100644 --- a/crates/daphne-worker/src/storage/mod.rs +++ b/crates/daphne-worker/src/storage/mod.rs @@ -55,17 +55,12 @@ impl> RequestBuilder<'_, B, P> { where R: DeserializeOwned, { - tracing::debug!( - obj = std::any::type_name::().split("::").last().unwrap(), - path = ?self.path, - "requesting DO", - ); let resp = storage_proxy::handle_do_request( self.durable.env, Default::default(), self.path.to_uri(), self.request, - |_, _, _| {}, + |_, _, _| {}, // retry metric ) .await?; diff --git a/crates/daphne-worker/src/storage_proxy/mod.rs b/crates/daphne-worker/src/storage_proxy/mod.rs index 661c9cf68..7a5bc40e7 100644 --- a/crates/daphne-worker/src/storage_proxy/mod.rs +++ b/crates/daphne-worker/src/storage_proxy/mod.rs @@ -257,6 +257,7 @@ async fn retry<T, E, F, Fut>(mut f: F) -> Result<T, E> where F: FnMut(u8) -> Fut, Fut: Future<Output = Result<T, E>>, + E: std::fmt::Debug, { const RETRY_DELAYS: &[Duration] = &[ Duration::from_millis(1_000), @@ -272,8 +273,10 @@ where Err(error) => { if attempt < attempts { Delay::from(RETRY_DELAYS[usize::from(attempt - 1)]).await; + tracing::warn!(attempt, error = ?error, "failed, retrying..."); attempt += 1; } else { + tracing::error!(attempt, error = ?error, "failed, aborting..."); return Err(error); } } @@ -461,7 +464,7 @@ pub async fn handle_do_request>( let mut do_req = RequestInit::new(); do_req.with_method(worker::Method::Post); do_req.with_headers(headers.into()); - tracing::debug!( + tracing::trace!( len = durable_request.body().len(), "deserializing do request" ); diff --git a/crates/daphne/src/lib.rs b/crates/daphne/src/lib.rs index a3ad3fc27..dbb5838ce 100644 --- a/crates/daphne/src/lib.rs +++ b/crates/daphne/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2022 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause //! This crate implements the core protocol logic for the Distributed Aggregation Protocol @@ -47,7 +47,7 @@ pub mod hpke; pub mod messages; pub mod metrics; pub mod pine; -pub(crate) mod protocol; +pub mod protocol; pub mod roles; pub mod taskprov; #[cfg(any(test, feature = "test-utils"))] @@ -355,7 +355,9 @@ impl IntoIterator for DapAggregateSpan { } impl ReportId { - fn shard(&self, num_shards: NonZeroUsize) -> usize { + /// Deterministically calculate a number between 0 and `num_shards` based on the report id. + /// Useful for sharding datastores. + pub fn shard(&self, num_shards: NonZeroUsize) -> usize { // NOTE This sharding scheme does not evenly distribute reports across all shards. // // First, the clients are supposed to choose the report ID at random; by finding collisions diff --git a/crates/daphne/src/messages/request.rs b/crates/daphne/src/messages/request.rs index 2a9fa4ee8..f5878de78 100644 --- a/crates/daphne/src/messages/request.rs +++ b/crates/daphne/src/messages/request.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use std::ops::Deref; @@ -16,6 +16,9 @@ pub trait RequestBody { type ResourceId; } +/// A poll request has no body, but requires an `AggregationJobId`. +pub struct PollAggregationJob; + /// A poll request has no body, but requires a `CollectionJobId`. pub struct CollectionPollReq; @@ -33,6 +36,7 @@ impl_req_body!
{ Report | () AggregationJobInitReq | AggregationJobId HashedAggregationJobReq | AggregationJobId + PollAggregationJob | AggregationJobId AggregateShareReq | () CollectionReq | CollectionJobId CollectionPollReq | CollectionJobId diff --git a/crates/daphne/src/protocol/aggregator.rs b/crates/daphne/src/protocol/aggregator.rs index 0925c3bc7..568960936 100644 --- a/crates/daphne/src/protocol/aggregator.rs +++ b/crates/daphne/src/protocol/aggregator.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use super::{ diff --git a/crates/daphne/src/protocol/mod.rs b/crates/daphne/src/protocol/mod.rs index 162947fd6..e8d6565dc 100644 --- a/crates/daphne/src/protocol/mod.rs +++ b/crates/daphne/src/protocol/mod.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use crate::messages; @@ -565,7 +565,7 @@ mod test { agg_job_resp.prep_resps[1] = tmp; assert_matches!( - t.consume_agg_job_resp_expect_err(leader_state, agg_job_resp), + t.consume_agg_job_resp_expect_err(leader_state, agg_job_resp,), DapError::Abort(DapAbort::InvalidMessage { .. }) ); } diff --git a/crates/daphne/src/protocol/report_init.rs b/crates/daphne/src/protocol/report_init.rs index 0cfcd6723..94b308052 100644 --- a/crates/daphne/src/protocol/report_init.rs +++ b/crates/daphne/src/protocol/report_init.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2024 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use crate::{ @@ -295,7 +295,7 @@ impl

<P> InitializedReport<P>

{ } } - pub(crate) fn metadata(&self) -> &ReportMetadata { + pub fn metadata(&self) -> &ReportMetadata { match self { Self::Ready { metadata, .. } | Self::Rejected { metadata, .. } => metadata, } diff --git a/crates/daphne/src/roles/aggregator.rs b/crates/daphne/src/roles/aggregator.rs index d9dd4bac9..f84c27e0f 100644 --- a/crates/daphne/src/roles/aggregator.rs +++ b/crates/daphne/src/roles/aggregator.rs @@ -1,4 +1,4 @@ -// Copyright (c) 2023 Cloudflare, Inc. All rights reserved. +// Copyright (c) 2025 Cloudflare, Inc. All rights reserved. // SPDX-License-Identifier: BSD-3-Clause use std::{collections::HashSet, ops::Range}; @@ -80,13 +80,19 @@ pub trait DapAggregator: HpkeProvider + Sized { /// collected batch. async fn is_batch_overlapping( &self, + version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result<bool, DapError>; /// Check whether the given batch ID has been observed before. This is called by the Leader /// (resp. Helper) in response to a CollectReq (resp. AggregateShareReq) for leader-selected tasks. - async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result<bool, DapError>; + async fn batch_exists( + &self, + version: DapVersion, + task_id: &TaskId, + batch_id: &BatchId, + ) -> Result<bool, DapError>; /// Store a set of output shares and mark the corresponding reports as aggregated. /// @@ -113,6 +119,7 @@ pub trait DapAggregator: HpkeProvider + Sized { /// Fetch the aggregate share for the given batch. async fn get_agg_share( &self, + version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result<DapAggregateShare, DapError>; @@ -120,6 +127,7 @@ pub trait DapAggregator: HpkeProvider + Sized { /// Mark a batch as collected. async fn mark_collected( &self, + version: DapVersion, task_id: &TaskId, batch_sel: &BatchSelector, ) -> Result<(), DapError>; diff --git a/crates/daphne/src/roles/helper/handle_agg_job.rs b/crates/daphne/src/roles/helper/handle_agg_job.rs index dad6f04be..4a667d0be 100644 --- a/crates/daphne/src/roles/helper/handle_agg_job.rs +++ b/crates/daphne/src/roles/helper/handle_agg_job.rs @@ -6,14 +6,19 @@ use crate::{ error::DapAbort, messages::{ AggregationJobInitReq, AggregationJobResp, PartialBatchSelector, PrepareRespVar, - ReportError, TaskId, + ReportError, ReportId, ReportMetadata, TaskId, }, metrics::ReportStatus, - protocol::aggregator::ReportProcessedStatus, + protocol::{aggregator::ReportProcessedStatus, ReadyAggregationJobResp}, roles::{aggregator::MergeAggShareError, resolve_task_config}, - DapAggregationParam, DapError, DapRequest, DapTaskConfig, InitializedReport, WithPeerPrepShare, + DapAggregateShare, DapAggregateSpan, DapAggregationParam, DapError, DapRequest, DapTaskConfig, + InitializedReport, WithPeerPrepShare, +}; +use std::{ + collections::{HashMap, HashSet}, + future::Future, + sync::Once, }; -use std::{collections::HashMap, sync::Once}; /// A state machine for the handling of aggregation jobs. pub struct HandleAggJob { state: State, } /// A state machine paired with a task config. /// /// This type is returned by [`HandleAggJob::into_parts`] and [`Self::with_initialized_reports`] /// can be used to return to the [`HandleAggJob`] state machine flow. -#[non_exhaustive] pub struct ToInitializedReportsTransition { pub task_id: TaskId, pub part_batch_sel: PartialBatchSelector, pub task_config: DapTaskConfig, } @@ -54,6 +58,8 @@ pub struct InitializedReports { reports: Vec<InitializedReport<WithPeerPrepShare>>, } +pub struct UniqueInitializedReports(InitializedReports); + macro_rules! impl_from { ($($t:ty),*$(,)?) => { $(impl From<$t> for HandleAggJob<$t> { @@ -350,4 +356,58 @@ impl HandleAggJob { // enabling a DoS attack.
Err(DapAbort::BadRequest("aggregation job contained too many replays".into()).into()) } + + pub async fn check_for_replays<F, Fut, E>( + mut self, + replay_check: F, + ) -> Result<HandleAggJob<UniqueInitializedReports>, E> + where + F: FnOnce(&mut dyn Iterator<Item = &ReportMetadata>) -> Fut, + Fut: Future<Output = Result<HashSet<ReportId>, E>>, + { + let replays = replay_check( + &mut self + .state + .reports + .iter() + .filter(|r| matches!(r, InitializedReport::Ready { .. })) + .map(|r| r.metadata()), + ) + .await?; + + for r in &mut self.state.reports { + if replays.contains(&r.metadata().id) { + *r = InitializedReport::Rejected { + metadata: r.metadata().clone(), + report_err: ReportError::ReportReplayed, + } + } + } + + Ok(HandleAggJob { + state: UniqueInitializedReports(self.state), + }) + } +} + +impl HandleAggJob<UniqueInitializedReports> { + pub fn finish( + self, + ) -> Result<(DapAggregateSpan<DapAggregateShare>, ReadyAggregationJobResp), DapError> { + let InitializedReports { + task_id, + part_batch_sel, + task_config, + reports, + agg_param, + } = self.state.0; + + task_config.produce_agg_job_resp( + task_id, + &agg_param, + &Default::default(), + &part_batch_sel, + &reports, + ) + } } diff --git a/crates/daphne/src/roles/helper/mod.rs b/crates/daphne/src/roles/helper/mod.rs index 2ae938bae..40ec8f2b7 100644 --- a/crates/daphne/src/roles/helper/mod.rs +++ b/crates/daphne/src/roles/helper/mod.rs @@ -11,8 +11,8 @@ use crate::{ constants::DapMediaType, error::DapAbort, messages::{ - constant_time_eq, AggregateShare, AggregateShareReq, AggregationJobId, - AggregationJobInitReq, PartialBatchSelector, TaskId, + constant_time_eq, request::PollAggregationJob, AggregateShare, AggregateShareReq, + AggregationJobId, AggregationJobInitReq, AggregationJobResp, PartialBatchSelector, TaskId, }, metrics::{DaphneRequestType, ReportStatus}, protocol::aggregator::ReplayProtection, @@ -98,6 +98,14 @@ pub trait DapHelper: DapAggregator { task_id: &TaskId, req_hash: &AggregationJobRequestHash, ) -> Result<(), DapError>; + + /// Polls for the completion of an aggregation job. + async fn poll_aggregated( + &self, + version: DapVersion, + task_id: &TaskId, + agg_job_id: &AggregationJobId, + ) -> Result<AggregationJobResp, DapError>; } pub async fn handle_agg_job_init_req( aggregator: &A, req: DapRequest, ) -> Result { ... }) } +pub async fn handle_agg_job_poll_req<A: DapHelper>( + aggregator: &A, + req: DapRequest<PollAggregationJob>, +) -> Result<DapResponse, DapError> { + let response = aggregator + .poll_aggregated(req.version, &req.task_id, &req.resource_id) + .await?; + Ok(DapResponse { + version: req.version, + media_type: DapMediaType::AggregationJobResp, + payload: response + .get_encoded_with_param(&req.version) + .map_err(DapError::encoding)?, + }) +} + /// Handle a request for an aggregate share. This is called by the Leader to complete a /// collection job. pub async fn handle_agg_share_req( aggregator: &A, req: DapRequest, ) -> Result { ... ) .await?; let agg_share = aggregator - .get_agg_share(&task_id, &req.payload.batch_sel) + .get_agg_share(req.version, &task_id, &req.payload.batch_sel) .await?; // Check that we have aggregated the same set of reports as the Leader. @@ -194,7 +219,7 @@ pub async fn handle_agg_share_req( // Mark each aggregated report as collected.
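// Once marked, `is_batch_overlapping` (via the `CheckCollected` command) rejects any later query that would cover these reports again.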
diff --git a/crates/daphne/src/roles/helper/mod.rs b/crates/daphne/src/roles/helper/mod.rs
index 2ae938bae..40ec8f2b7 100644
--- a/crates/daphne/src/roles/helper/mod.rs
+++ b/crates/daphne/src/roles/helper/mod.rs
@@ -11,8 +11,8 @@ use crate::{
     constants::DapMediaType,
     error::DapAbort,
     messages::{
-        constant_time_eq, AggregateShare, AggregateShareReq, AggregationJobId,
-        AggregationJobInitReq, PartialBatchSelector, TaskId,
+        constant_time_eq, request::PollAggregationJob, AggregateShare, AggregateShareReq,
+        AggregationJobId, AggregationJobInitReq, AggregationJobResp, PartialBatchSelector, TaskId,
     },
     metrics::{DaphneRequestType, ReportStatus},
     protocol::aggregator::ReplayProtection,
@@ -98,6 +98,14 @@ pub trait DapHelper: DapAggregator {
         task_id: &TaskId,
         req_hash: &AggregationJobRequestHash,
     ) -> Result<(), DapError>;
+
+    /// Polls for the completion of an aggregation job.
+    async fn poll_aggregated(
+        &self,
+        version: DapVersion,
+        task_id: &TaskId,
+        agg_job_id: &AggregationJobId,
+    ) -> Result<AggregationJobResp, DapError>;
 }
 
 pub async fn handle_agg_job_init_req<A: DapHelper>(
@@ -129,6 +137,22 @@ pub async fn handle_agg_job_init_req<A: DapHelper>(
     })
 }
 
+pub async fn handle_agg_job_poll_req<A: DapHelper>(
+    aggregator: &A,
+    req: DapRequest<PollAggregationJob>,
+) -> Result<DapResponse, DapError> {
+    let response = aggregator
+        .poll_aggregated(req.version, &req.task_id, &req.resource_id)
+        .await?;
+    Ok(DapResponse {
+        version: req.version,
+        media_type: DapMediaType::AggregationJobResp,
+        payload: response
+            .get_encoded_with_param(&req.version)
+            .map_err(DapError::encoding)?,
+    })
+}
+
 /// Handle a request for an aggregate share. This is called by the Leader to complete a
 /// collection job.
 pub async fn handle_agg_share_req<A: DapHelper>(
@@ -159,8 +183,8 @@ pub async fn handle_agg_share_req<A: DapHelper>(
     )
     .await?;
 
     let agg_share = aggregator
-        .get_agg_share(&task_id, &req.payload.batch_sel)
+        .get_agg_share(req.version, &task_id, &req.payload.batch_sel)
         .await?;
 
     // Check that we have aggregated the same set of reports as the Leader.
@@ -194,7 +219,7 @@
 
     // Mark each aggregated report as collected.
     aggregator
-        .mark_collected(&task_id, &req.payload.batch_sel)
+        .mark_collected(task_config.version, &task_id, &req.payload.batch_sel)
         .await?;
 
     let encrypted_agg_share = task_config.produce_helper_encrypted_agg_share(
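`handle_agg_job_poll_req` is the read side of the new asynchronous aggregation flow: the Leader submits the `AggregationJobInitReq`, then re-fetches the same resource until the Helper reports completion. A sketch of the round trip as the caller sees it — the `app` value implementing `DapHelper`, the decoded `req: DapRequest<PollAggregationJob>`, and `task_id` are assumptions for illustration:

```rust
// Sketch: one poll of the helper's aggregation-job resource.
let resp = handle_agg_job_poll_req(&app, req).await?;
match AggregationJobResp::get_decoded_with_param(&resp.version, &resp.payload) {
    Ok(AggregationJobResp::Ready { prep_resps }) => {
        // Job finished: feed prep_resps into consume_agg_job_resp.
    }
    Ok(AggregationJobResp::Processing) => {
        // Not done yet: retry the poll later.
    }
    Err(e) => return Err(DapAbort::from_codec_error(e, task_id).into()),
}
```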
diff --git a/crates/daphne/src/roles/leader/mod.rs b/crates/daphne/src/roles/leader/mod.rs
index a188b6555..9b1dcf301 100644
--- a/crates/daphne/src/roles/leader/mod.rs
+++ b/crates/daphne/src/roles/leader/mod.rs
@@ -1,4 +1,4 @@
-// Copyright (c) 2023 Cloudflare, Inc. All rights reserved.
+// Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
 pub mod in_memory_leader;
@@ -23,7 +23,7 @@ use crate::{
         CollectionReq, Extension, Interval, PartialBatchSelector, Query, Report, TaskId,
     },
     metrics::{DaphneRequestType, ReportStatus},
-    protocol,
+    protocol::{self, ReadyAggregationJobResp},
     roles::resolve_task_config,
     DapAggregationParam, DapCollectionJob, DapError, DapLeaderProcessTelemetry, DapRequest,
     DapRequestMeta, DapResponse, DapTaskConfig, DapVersion,
@@ -39,6 +39,7 @@ struct LeaderHttpRequestOptions<'p, P> {
 }
 
 enum LeaderHttpRequestMethod {
+    Get,
     Post,
     Put,
 }
@@ -74,6 +75,7 @@ where
     };
 
     let resp = match method {
+        LeaderHttpRequestMethod::Get => role.send_http_get(meta, url).await?,
        LeaderHttpRequestMethod::Put => role.send_http_put(meta, url, req_data).await?,
         LeaderHttpRequestMethod::Post => role.send_http_post(meta, url, req_data).await?,
     };
@@ -151,6 +153,8 @@ pub trait DapLeader: DapAggregator {
         collect_resp: &Collection,
     ) -> Result<(), DapError>;
 
+    async fn send_http_get(&self, req: DapRequestMeta, url: Url) -> Result<DapResponse, DapError>;
+
     /// Send an HTTP POST request.
     async fn send_http_post(
         &self,
@@ -369,20 +373,39 @@ async fn run_agg_job<A: DapLeader>(
         },
     )
     .await?;
 
-    let agg_job_resp =
+    let mut agg_job_resp =
         AggregationJobResp::get_decoded_with_param(&task_config.version, &resp.payload)
             .map_err(|e| DapAbort::from_codec_error(e, *task_id))?;
 
-    let agg_job_resp = match agg_job_resp {
-        AggregationJobResp::Ready { prep_resps } => {
-            crate::protocol::ReadyAggregationJobResp { prep_resps }
-        }
-        AggregationJobResp::Processing => todo!("polling not implemented yet"),
+    let ready = loop {
+        let resp = match agg_job_resp {
+            AggregationJobResp::Ready { prep_resps } => {
+                break ReadyAggregationJobResp { prep_resps }
+            }
+            AggregationJobResp::Processing => {
+                leader_send_http_request(
+                    aggregator,
+                    task_id,
+                    task_config,
+                    LeaderHttpRequestOptions {
+                        path: &url_path,
+                        req_media_type: DapMediaType::AggregationJobInitReq,
+                        resp_media_type: DapMediaType::AggregationJobResp,
+                        req_data: (),
+                        method: LeaderHttpRequestMethod::Get,
+                        taskprov_advertisement: taskprov_advertisement.clone(),
+                    },
+                )
+                .await?
+            }
+        };
+        agg_job_resp =
+            AggregationJobResp::get_decoded_with_param(&task_config.version, &resp.payload)
+                .map_err(|e| DapAbort::from_codec_error(e, *task_id))?;
     };
 
     // Handle AggregationJobResp.
-    let agg_span =
-        task_config.consume_agg_job_resp(task_id, agg_job_state, agg_job_resp, metrics)?;
+    let agg_span = task_config.consume_agg_job_resp(task_id, agg_job_state, ready, metrics)?;
 
     let out_shares_count = agg_span.report_count() as u64;
     debug!("computed {out_shares_count} output shares");
@@ -441,7 +464,11 @@ async fn run_coll_job<A: DapLeader>(
     let metrics = aggregator.metrics();
 
     debug!("collecting id {coll_job_id}");
-    let leader_agg_share = aggregator.get_agg_share(task_id, batch_sel).await?;
+    let leader_agg_share = aggregator
+        // We need to hardcode the DAP version here because the leader uses the old
+        // aggregate-store method of storing the aggregate share.
+        .get_agg_share(DapVersion::Draft09, task_id, batch_sel)
+        .await?;
 
     let taskprov_advertisement = task_config.resolve_taskprove_advertisement()?;
@@ -522,7 +549,7 @@ async fn run_coll_job<A: DapLeader>(
 
     // Mark reports as collected.
     aggregator
-        .mark_collected(task_id, &agg_share_req_batch_sel)
+        .mark_collected(task_config.version, task_id, &agg_share_req_batch_sel)
         .await?;
 
     metrics.report_inc_by(ReportStatus::Collected, agg_share_req_report_count);
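Note that the loop in `run_agg_job` above re-issues the GET immediately whenever the Helper answers `Processing`; a production Leader would likely want to wait between polls. A self-contained sketch of a linear-backoff helper (hypothetical, not part of this change; assumes a tokio runtime):

```rust
use std::{future::Future, time::Duration};

// Poll an async operation until it yields Some(T), sleeping a linearly
// growing interval between attempts and giving up after `max_attempts`.
async fn poll_with_backoff<T, F, Fut>(mut poll: F, max_attempts: u64) -> Option<T>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Option<T>>,
{
    for attempt in 1..=max_attempts {
        if let Some(out) = poll().await {
            return Some(out);
        }
        // 200ms, 400ms, 600ms, ...; a production version might cap the delay
        // or switch to exponential growth with jitter.
        tokio::time::sleep(Duration::from_millis(attempt * 200)).await;
    }
    None
}
```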
diff --git a/crates/daphne/src/roles/mod.rs b/crates/daphne/src/roles/mod.rs
index 2ab83e826..f47f97857 100644
--- a/crates/daphne/src/roles/mod.rs
+++ b/crates/daphne/src/roles/mod.rs
@@ -1,4 +1,4 @@
-// Copyright (c) 2023 Cloudflare, Inc. All rights reserved.
+// Copyright (c) 2025 Cloudflare, Inc. All rights reserved.
 // SPDX-License-Identifier: BSD-3-Clause
 
 //! Trait definitions for Daphne backends.
@@ -67,7 +67,10 @@ async fn check_batch(
         }
         (DapBatchMode::LeaderSelected { .. }, Query::LeaderSelectedCurrentBatch) => (), // nothing to do
         (DapBatchMode::LeaderSelected { .. }, Query::LeaderSelectedByBatchId { batch_id }) => {
-            if !agg.batch_exists(task_id, batch_id).await? {
+            if !agg
+                .batch_exists(task_config.version, task_id, batch_id)
+                .await?
+            {
                 return Err(DapAbort::BatchInvalid {
                     detail: format!(
                         "The queried batch ({}) does not exist.",
@@ -83,7 +86,10 @@ async fn check_batch(
 
     // Check that the batch does not overlap with any previously collected batch.
     if let Some(batch_sel) = query.into_batch_sel() {
-        if agg.is_batch_overlapping(task_id, &batch_sel).await? {
+        if agg
+            .is_batch_overlapping(task_config.version, task_id, &batch_sel)
+            .await?
+        {
             return Err(DapAbort::batch_overlap(task_id, query).into());
         }
     }
@@ -91,7 +97,7 @@ async fn check_batch(
     Ok(())
 }
 
-async fn resolve_task_config(
+pub async fn resolve_task_config(
     agg: &impl DapAggregator,
     req: &DapRequestMeta,
 ) -> Result<DapTaskConfig, DapError>
diff --git a/crates/daphne/src/testing/mod.rs b/crates/daphne/src/testing/mod.rs
index fd3f1fbac..1915c8130 100644
--- a/crates/daphne/src/testing/mod.rs
+++ b/crates/daphne/src/testing/mod.rs
@@ -531,6 +531,7 @@ pub struct InMemoryAggregator {
 
     // Helper: aggregation jobs
     processed_jobs: Mutex<HashMap<AggregationJobId, AggregationJobRequestHash>>,
+    finished_jobs: Mutex<HashMap<AggregationJobId, ReadyAggregationJobResp>>,
 }
 
 impl DeepSizeOf for InMemoryAggregator {
@@ -547,6 +548,7 @@ impl DeepSizeOf for InMemoryAggregator {
             taskprov_vdaf_verify_key_init,
             peer,
             processed_jobs,
+            finished_jobs,
         } = self;
         global_config.deep_size_of_children(context)
             + tasks.deep_size_of_children(context)
@@ -557,6 +559,7 @@ impl DeepSizeOf for InMemoryAggregator {
             + taskprov_vdaf_verify_key_init.deep_size_of_children(context)
             + peer.deep_size_of_children(context)
             + processed_jobs.deep_size_of_children(context)
+            + finished_jobs.deep_size_of_children(context)
     }
 }
@@ -581,6 +584,7 @@ impl InMemoryAggregator {
             taskprov_vdaf_verify_key_init,
             peer: None,
             processed_jobs: Default::default(),
+            finished_jobs: Default::default(),
         }
     }
 
@@ -605,6 +609,7 @@ impl InMemoryAggregator {
             taskprov_vdaf_verify_key_init,
             peer: peer.into(),
             processed_jobs: Default::default(),
+            finished_jobs: Default::default(),
         }
     }
 
@@ -723,6 +728,7 @@ impl DapAggregator for InMemoryAggregator {
 
     async fn is_batch_overlapping(
         &self,
+        _version: DapVersion,
         task_id: &TaskId,
         batch_sel: &BatchSelector,
     ) -> Result<bool, DapError> {
@@ -746,7 +752,12 @@ impl DapAggregator for InMemoryAggregator {
         Ok(false)
     }
 
-    async fn batch_exists(&self, task_id: &TaskId, batch_id: &BatchId) -> Result<bool, DapError> {
+    async fn batch_exists(
+        &self,
+        _version: DapVersion,
+        task_id: &TaskId,
+        batch_id: &BatchId,
+    ) -> Result<bool, DapError> {
         let task_config = self
             .get_task_config_for(task_id)
             .await?
@@ -826,6 +837,7 @@ impl DapAggregator for InMemoryAggregator {
 
     async fn get_agg_share(
         &self,
+        _version: DapVersion,
         task_id: &TaskId,
         batch_sel: &BatchSelector,
     ) -> Result<DapAggregateShare, DapError> {
@@ -854,6 +866,7 @@ impl DapAggregator for InMemoryAggregator {
 
     async fn mark_collected(
         &self,
+        _version: DapVersion,
         task_id: &TaskId,
         batch_sel: &BatchSelector,
     ) -> Result<(), DapError> {
@@ -910,6 +923,26 @@ impl DapHelper for InMemoryAggregator {
             }
         }
     }
+
+    async fn poll_aggregated(
+        &self,
+        _version: DapVersion,
+        task_id: &TaskId,
+        agg_job_id: &AggregationJobId,
+    ) -> Result<AggregationJobResp, DapError> {
+        self.finished_jobs
+            .lock()
+            .unwrap()
+            .get(agg_job_id)
+            .cloned()
+            .map(Into::into)
+            .ok_or_else(|| {
+                DapError::Abort(DapAbort::UnrecognizedAggregationJob {
+                    task_id: *task_id,
+                    agg_job_id: *agg_job_id,
+                })
+            })
+    }
 }
 
 #[async_trait]
@@ -999,6 +1032,25 @@ impl DapLeader for InMemoryAggregator {
             .finish_collect_job(task_id, coll_job_id, collection)
     }
 
+    async fn send_http_get(&self, meta: DapRequestMeta, url: Url) -> Result<DapResponse, DapError> {
+        match meta.media_type {
+            Some(DapMediaType::AggregationJobInitReq) => Ok(helper::handle_agg_job_poll_req(
+                &**self.peer.as_ref().expect("peer not configured"),
+                DapRequest {
+                    meta,
+                    resource_id: AggregationJobId::try_from_base64url(
+                        url.path().split('/').last().unwrap(),
+                    )
+                    .unwrap(),
+                    payload: messages::request::PollAggregationJob,
+                },
+            )
+            .await
+            .expect("peer aborted unexpectedly")),
+            _ => unreachable!("unhandled media type: {:?}", meta.media_type),
+        }
+    }
+
     async fn send_http_post(
         &self,
         meta: DapRequestMeta,
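The in-memory Helper answers polls out of the new `finished_jobs` map, aborting with `UnrecognizedAggregationJob` for unknown IDs. This diff does not show where `finished_jobs` is populated; presumably the job-processing path inserts the `ReadyAggregationJobResp` once aggregation completes, along the lines of the sketch below (`handler` and the insertion site are assumptions):

```rust
// Sketch: record the finished job so later poll_aggregated calls return
// AggregationJobResp::Ready instead of an UnrecognizedAggregationJob abort.
let (agg_span, ready) = handler.finish()?;
self.finished_jobs
    .lock()
    .unwrap()
    .insert(*agg_job_id, ready.clone());
```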
diff --git a/interop/Dockerfile.interop_helper b/interop/Dockerfile.interop_helper
index f646dd078..a8192c1da 100644
--- a/interop/Dockerfile.interop_helper
+++ b/interop/Dockerfile.interop_helper
@@ -6,14 +6,14 @@
 # NOTE: We must use debian (bookworm). We cannot use alpine because building
 # the service proxy requires OpenSSL, which is not compatible with the musl
 # target required by alpine.
-FROM rust:1.80-bookworm AS build-deps-common
+FROM rust:1.84.1-bookworm AS build-deps-common
 RUN apt update && apt install -y capnproto clang
 RUN capnp --version
 
 # Prepare dependencies for building the storage proxy.
 FROM build-deps-common AS build-deps-storage-proxy
 RUN rustup target add wasm32-unknown-unknown
-RUN cargo install worker-build@0.1.1 --locked
+RUN cargo install worker-build@0.1.2 --locked
 
 # Build the service.
 FROM build-deps-common AS builder-service