Simplify DapRequest generics
mendess committed Nov 27, 2024
1 parent 05c3473 commit 6c1d3d8
Showing 8 changed files with 190 additions and 130 deletions.
177 changes: 120 additions & 57 deletions crates/daphne-server/src/router/extractor.rs

Large diffs are not rendered by default.
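
The extractor changes are collapsed above, so here is a minimal, self-contained sketch of the shape the extractor wrappers appear to take after this commit, inferred from the call sites in router/helper.rs and router/leader.rs below. The `u8` sender constant, the local stand-in types, and the omission of request metadata are assumptions for illustration only, not the crate's actual definitions.

```rust
// Stand-in trait mirroring the `RequestBody` trait added in
// crates/daphne/src/messages/request.rs (see that file's diff below).
pub trait RequestBody {
    type ResourceId;
}

// Simplified stand-in for `DapRequest<B>`: the resource-ID type now follows
// from the body type instead of being a second generic parameter.
pub struct DapRequest<B: RequestBody> {
    pub resource_id: B::ResourceId,
    pub payload: B,
}

// Authenticated extractor: the sender is fixed by a const parameter and the
// body type alone determines the resource ID.
pub struct DapRequestExtractor<const SENDER: u8, B: RequestBody>(pub DapRequest<B>);

// Unauthenticated variant used for the upload endpoint.
pub struct UnauthenticatedDapRequestExtractor<B: RequestBody>(pub DapRequest<B>);

fn main() {}
```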

8 changes: 2 additions & 6 deletions crates/daphne-server/src/router/helper.rs
@@ -8,7 +8,7 @@ use axum::{
routing::{post, put},
};
use daphne::{
messages::{request::resource, AggregateShareReq, AggregationJobInitReq},
messages::{AggregateShareReq, AggregationJobInitReq},
roles::{helper, DapHelper},
};
use http::StatusCode;
@@ -38,11 +38,7 @@ pub(super) fn add_helper_routes(router: super::Router<App>) -> super::Router<App
)]
async fn agg_job(
State(app): State<Arc<App>>,
DapRequestExtractor(req): DapRequestExtractor<
FROM_LEADER,
AggregationJobInitReq,
resource::AggregationJobId,
>,
DapRequestExtractor(req): DapRequestExtractor<FROM_LEADER, AggregationJobInitReq>,
) -> AxumDapResponse {
let timer = std::time::Instant::now();

15 changes: 4 additions & 11 deletions crates/daphne-server/src/router/leader.rs
@@ -13,7 +13,7 @@ use axum::{
use daphne::{
constants::DapMediaType,
error::DapAbort,
messages::{self, request::resource},
messages::{self, request::CollectionPollReq},
roles::leader::{self, DapLeader},
DapError, DapVersion,
};
@@ -82,10 +82,7 @@ where
)]
async fn upload<A>(
State(app): State<Arc<A>>,
UnauthenticatedDapRequestExtractor(req): UnauthenticatedDapRequestExtractor<
messages::Report,
resource::None,
>,
UnauthenticatedDapRequestExtractor(req): UnauthenticatedDapRequestExtractor<messages::Report>,
) -> Response
where
A: DapLeader + DaphneService + Send + Sync,
@@ -105,11 +102,7 @@ where
)]
async fn start_collection_job<A>(
State(app): State<Arc<A>>,
DapRequestExtractor(req): DapRequestExtractor<
FROM_COLLECTOR,
messages::CollectionReq,
resource::CollectionJobId,
>,
DapRequestExtractor(req): DapRequestExtractor<FROM_COLLECTOR, messages::CollectionReq>,
) -> Response
where
A: DapLeader + DaphneService + Send + Sync,
@@ -129,7 +122,7 @@
)]
async fn poll_collect<A>(
State(app): State<Arc<A>>,
DapRequestExtractor(req): DapRequestExtractor<FROM_COLLECTOR, (), resource::CollectionJobId>,
DapRequestExtractor(req): DapRequestExtractor<FROM_COLLECTOR, CollectionPollReq>,
) -> Response
where
A: DapLeader + DaphneService + Send + Sync,
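
`poll_collect` above now names `CollectionPollReq` as its body type even though a poll request sends no payload; the mapping from that marker type to `CollectionJobId` is added in crates/daphne/src/messages/request.rs (next file). A standalone sketch of the idea, using local stand-in types rather than the daphne crate's own:

```rust
// A body-less request can still pin down its resource-ID type through the
// `RequestBody` associated type, so the handler receives a typed collection
// job ID without a separate resource generic. All types here are stand-ins.
trait RequestBody {
    type ResourceId;
}

#[derive(Debug)]
struct CollectionJobId([u8; 16]);

// Marker body: polling a collection job sends no payload.
struct CollectionPollReq;

impl RequestBody for CollectionPollReq {
    type ResourceId = CollectionJobId;
}

struct DapRequest<B: RequestBody> {
    resource_id: B::ResourceId,
    payload: B,
}

fn main() {
    let req = DapRequest {
        resource_id: CollectionJobId([0; 16]),
        payload: CollectionPollReq,
    };
    // The resource ID stays fully typed even though the body is empty.
    println!("polling collection job {:?}", req.resource_id);
}
```
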
54 changes: 33 additions & 21 deletions crates/daphne/src/messages/request.rs
@@ -3,24 +3,36 @@

use std::ops::Deref;

use super::taskprov::TaskprovAdvertisement;
use super::{
taskprov::TaskprovAdvertisement, AggregateShareReq, AggregationJobId, AggregationJobInitReq,
CollectionJobId, CollectionReq, Report,
};
use crate::{constants::DapMediaType, error::DapAbort, messages::TaskId, DapVersion};

pub mod resource {
/// Aggregation job resource.
pub use crate::messages::AggregationJobId;
/// Collection job resource.
pub use crate::messages::CollectionJobId;

/// Undefined (or undetermined) resource.
///
/// The resource of a DAP request is undefined if there is not a unique object (in the context
/// of a DAP task) that the request pertains to. For example:
///
/// * The Client->Aggregator request for the HPKE config or to upload a report
/// * The Leader->Helper request for an aggregate share
#[derive(Clone, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Copy)]
pub struct None;
pub trait RequestBody {
type ResourceId;
}

/// A poll request has no body, but requires a `CollectionJobId`.
pub struct CollectionPollReq;

macro_rules! impl_req_body {
($($body:tt | $id:tt)*) => {
$(impl RequestBody for $body {
type ResourceId = $id;
})*
};
}

impl_req_body! {
// body type | id type
// --------------------- | ----------------
Report | ()
AggregationJobInitReq | AggregationJobId
AggregateShareReq | ()
CollectionReq | CollectionJobId
CollectionPollReq | CollectionJobId
() | ()
}

/// Fields common to all DAP requests.
@@ -59,23 +71,23 @@ impl DapRequestMeta {
/// - [`None`](resource::None)

Check failure on line 71 in crates/daphne/src/messages/request.rs (GitHub Actions / Testing): unresolved link to `resource::None`
#[derive(Debug)]
#[cfg_attr(test, derive(Default))]
pub struct DapRequest<P, R = resource::None> {
pub struct DapRequest<B: RequestBody> {
pub meta: DapRequestMeta,

/// The resource with which this request is associated.
pub resource_id: R,
pub resource_id: B::ResourceId,

/// Request payload.
pub payload: P,
pub payload: B,
}

impl<P, R> AsRef<DapRequestMeta> for DapRequest<P, R> {
impl<B: RequestBody> AsRef<DapRequestMeta> for DapRequest<B> {
fn as_ref(&self) -> &DapRequestMeta {
&self.meta
}
}

impl<P, R> Deref for DapRequest<P, R> {
impl<B: RequestBody> Deref for DapRequest<B> {
type Target = DapRequestMeta;
fn deref(&self) -> &Self::Target {
&self.meta
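
For reference, here is a self-contained sketch of the scheme this file introduces, runnable on its own: the `RequestBody` trait and `impl_req_body!` macro are copied from the diff above, while the message and metadata types are empty stand-ins rather than the real daphne definitions.

```rust
// `RequestBody` ties each request body type to its resource-ID type, a small
// macro writes the impls, and `DapRequest` needs only one type parameter.
use std::ops::Deref;

pub trait RequestBody {
    type ResourceId;
}

// Stand-in message and resource-ID types.
pub struct Report;
pub struct AggregationJobInitReq;
pub struct AggregateShareReq;
pub struct CollectionReq;
pub struct CollectionPollReq; // no payload, but still tied to a collection job
pub struct AggregationJobId(pub [u8; 16]);
pub struct CollectionJobId(pub [u8; 16]);

macro_rules! impl_req_body {
    ($($body:tt | $id:tt)*) => {
        $(impl RequestBody for $body {
            type ResourceId = $id;
        })*
    };
}

impl_req_body! {
    // body type            | id type
    Report                  | ()
    AggregationJobInitReq   | AggregationJobId
    AggregateShareReq       | ()
    CollectionReq           | CollectionJobId
    CollectionPollReq       | CollectionJobId
    ()                      | ()
}

pub struct DapRequestMeta; // stand-in for the shared request metadata

pub struct DapRequest<B: RequestBody> {
    pub meta: DapRequestMeta,
    pub resource_id: B::ResourceId,
    pub payload: B,
}

impl<B: RequestBody> Deref for DapRequest<B> {
    type Target = DapRequestMeta;
    fn deref(&self) -> &Self::Target {
        &self.meta
    }
}

fn main() {
    // The resource-ID type is fixed by the body type: `Report` has no
    // resource, while an aggregation job init request names its job ID.
    let _upload = DapRequest {
        meta: DapRequestMeta,
        resource_id: (),
        payload: Report,
    };
    let _agg_init = DapRequest {
        meta: DapRequestMeta,
        resource_id: AggregationJobId([0; 16]),
        payload: AggregationJobInitReq,
    };
}
```
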
7 changes: 3 additions & 4 deletions crates/daphne/src/roles/helper.rs
@@ -11,9 +11,8 @@ use crate::{
constants::DapMediaType,
error::DapAbort,
messages::{
constant_time_eq, request::resource, AggregateShare, AggregateShareReq,
AggregationJobInitReq, AggregationJobResp, PartialBatchSelector, TaskId, TransitionFailure,
TransitionVar,
constant_time_eq, AggregateShare, AggregateShareReq, AggregationJobInitReq,
AggregationJobResp, PartialBatchSelector, TaskId, TransitionFailure, TransitionVar,
},
metrics::{DaphneMetrics, DaphneRequestType, ReportStatus},
protocol::aggregator::{InitializedReport, ReplayProtection, ReportProcessedStatus},
@@ -27,7 +26,7 @@ pub trait DapHelper: DapAggregator {}

pub async fn handle_agg_job_init_req<A: DapHelper + Sync>(
aggregator: &A,
req: DapRequest<AggregationJobInitReq, resource::AggregationJobId>,
req: DapRequest<AggregationJobInitReq>,
replay_protection: ReplayProtection,
) -> Result<DapResponse, DapError> {
let task_id = req.task_id;
8 changes: 4 additions & 4 deletions crates/daphne/src/roles/leader/mod.rs
@@ -18,9 +18,9 @@ use crate::{
error::DapAbort,
fatal_error,
messages::{
request::resource, taskprov::TaskprovAdvertisement, AggregateShare, AggregateShareReq,
AggregationJobId, AggregationJobResp, Base64Encode, BatchId, BatchSelector, Collection,
CollectionJobId, CollectionReq, Interval, PartialBatchSelector, Query, Report, TaskId,
taskprov::TaskprovAdvertisement, AggregateShare, AggregateShareReq, AggregationJobId,
AggregationJobResp, Base64Encode, BatchId, BatchSelector, Collection, CollectionJobId,
CollectionReq, Interval, PartialBatchSelector, Query, Report, TaskId,
},
metrics::{DaphneRequestType, ReportStatus},
roles::resolve_task_config,
@@ -236,7 +236,7 @@ pub async fn handle_upload_req<A: DapLeader>(
/// poll later on to get the collection.
pub async fn handle_coll_job_req<A: DapLeader>(
aggregator: &A,
req: &DapRequest<CollectionReq, resource::CollectionJobId>,
req: &DapRequest<CollectionReq>,
) -> Result<(), DapError> {
let global_config = aggregator.get_global_config().await?;
let now = aggregator.get_current_time();
39 changes: 18 additions & 21 deletions crates/daphne/src/roles/mod.rs
@@ -136,7 +136,7 @@ mod test {
constants::DapMediaType,
hpke::{HpkeKemId, HpkeProvider, HpkeReceiverConfig},
messages::{
request::resource, AggregateShareReq, AggregationJobId, AggregationJobInitReq,
request::RequestBody, AggregateShareReq, AggregationJobId, AggregationJobInitReq,
AggregationJobResp, BatchId, BatchSelector, Collection, CollectionJobId, CollectionReq,
Extension, HpkeCiphertext, Interval, PartialBatchSelector, Query, Report, TaskId, Time,
TransitionFailure, TransitionVar,
@@ -389,7 +389,7 @@ mod test {
&self,
report: Report,
task_id: &TaskId,
) -> DapRequest<Report, resource::None> {
) -> DapRequest<Report> {
let task_config = self.leader.unchecked_get_task_config(task_id).await;
let version = task_config.version;

@@ -400,7 +400,7 @@
task_id: *task_id,
..Default::default()
},
resource_id: resource::None,
resource_id: Default::default(),
payload: report,
}
}
@@ -409,7 +409,7 @@
&self,
query: Query,
task_id: &TaskId,
) -> DapRequest<CollectionReq, resource::CollectionJobId> {
) -> DapRequest<CollectionReq> {
self.gen_test_coll_job_req_for_collection(query, DapAggregationParam::Empty, task_id)
.await
}
@@ -419,7 +419,7 @@
query: Query,
agg_param: DapAggregationParam,
task_id: &TaskId,
) -> DapRequest<CollectionReq, resource::CollectionJobId> {
) -> DapRequest<CollectionReq> {
let task_config = self.leader.unchecked_get_task_config(task_id).await;

Self::collector_req(
@@ -438,10 +438,7 @@
task_id: &TaskId,
agg_param: DapAggregationParam,
reports: Vec<Report>,
) -> (
DapAggregationJobState,
DapRequest<AggregationJobInitReq, resource::AggregationJobId>,
) {
) -> (DapAggregationJobState, DapRequest<AggregationJobInitReq>) {
let mut rng = thread_rng();
let task_config = self.leader.unchecked_get_task_config(task_id).await;
let part_batch_sel = match task_config.query {
@@ -520,13 +517,13 @@
.unwrap()
}

pub fn leader_req<P, R>(
pub fn leader_req<B: RequestBody>(
task_id: &TaskId,
task_config: &DapTaskConfig,
agg_job_id: R,
agg_job_id: B::ResourceId,
media_type: DapMediaType,
payload: P,
) -> DapRequest<P, R> {
payload: B,
) -> DapRequest<B> {
DapRequest {
meta: DapRequestMeta {
version: task_config.version,
Expand All @@ -539,12 +536,12 @@ mod test {
}
}

pub fn collector_req<C>(
pub fn collector_req(
task_id: &TaskId,
task_config: &DapTaskConfig,
media_type: DapMediaType,
payload: C,
) -> DapRequest<C, resource::CollectionJobId> {
payload: CollectionReq,
) -> DapRequest<CollectionReq> {
let mut rng = thread_rng();
let coll_job_id = CollectionJobId(rng.gen());

@@ -667,7 +664,7 @@ mod test {
let req = Test::leader_req(
&t.time_interval_task_id,
&task_config,
resource::None,
(),
DapMediaType::AggregateShareReq,
AggregateShareReq {
batch_sel: BatchSelector::FixedSizeByBatchId {
@@ -693,7 +690,7 @@
let req = Test::leader_req(
&t.fixed_size_task_id,
&task_config,
resource::None,
(),
DapMediaType::AggregateShareReq,
AggregateShareReq {
batch_sel: BatchSelector::FixedSizeByBatchId {
@@ -875,7 +872,7 @@ mod test {
task_id: TaskId([0; 32]),
..Default::default()
},
resource_id: resource::None,
resource_id: (),
payload: report_invalid_task_id,
};

@@ -904,7 +901,7 @@ mod test {
task_id: *task_id,
..Default::default()
},
resource_id: resource::None,
resource_id: (),
payload: report.clone(),
};

@@ -1464,7 +1461,7 @@ mod test {
task_id,
taskprov_advertisement: Some(taskprov_advertisement.clone()),
},
resource_id: resource::None,
resource_id: (),
payload: report,
};
leader::handle_upload_req(&*t.leader, req).await.unwrap();
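
`leader_req` in the tests above is now generic over `B: RequestBody`, so the type of the resource-ID argument follows from the body type, and call sites pass `()` where there is no resource (as the `AggregateShareReq` tests do). A simplified, standalone sketch of that inference, with stand-in types and the unrelated parameters dropped:

```rust
// Minimal model of the generic test helper: the resource-ID parameter's type
// is `B::ResourceId`, so it is inferred from the payload. Stand-in types only.
trait RequestBody {
    type ResourceId;
}

struct AggregationJobId(u64);
struct AggregationJobInitReq;
struct AggregateShareReq;

impl RequestBody for AggregationJobInitReq {
    type ResourceId = AggregationJobId;
}
impl RequestBody for AggregateShareReq {
    type ResourceId = ();
}

struct DapRequest<B: RequestBody> {
    resource_id: B::ResourceId,
    payload: B,
}

// Simplified stand-in for the test helper: task ID, task config, and media
// type are omitted here.
fn leader_req<B: RequestBody>(resource_id: B::ResourceId, payload: B) -> DapRequest<B> {
    DapRequest { resource_id, payload }
}

fn main() {
    // An aggregation job init request names its job ID...
    let _init = leader_req(AggregationJobId(7), AggregationJobInitReq);
    // ...while an aggregate share request has no resource, so `()` suffices.
    let _share = leader_req((), AggregateShareReq);
}
```
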
12 changes: 6 additions & 6 deletions crates/daphne/src/testing/mod.rs
@@ -12,9 +12,9 @@ use crate::{
fatal_error,
hpke::{HpkeConfig, HpkeKemId, HpkeProvider, HpkeReceiverConfig},
messages::{
self, request::resource, AggregationJobId, AggregationJobInitReq, AggregationJobResp,
Base64Encode, BatchId, BatchSelector, Collection, CollectionJobId, HpkeCiphertext,
Interval, PartialBatchSelector, Report, ReportId, TaskId, Time,
self, AggregationJobId, AggregationJobInitReq, AggregationJobResp, Base64Encode, BatchId,
BatchSelector, Collection, CollectionJobId, HpkeCiphertext, Interval, PartialBatchSelector,
Report, ReportId, TaskId, Time,
},
metrics::{prometheus::DaphnePromMetrics, DaphneMetrics},
roles::{
@@ -972,7 +972,7 @@ impl DapLeader for InMemoryAggregator {
&**self.peer.as_ref().expect("peer not configured"),
DapRequest {
payload: re_encode(&meta, payload),
resource_id: resource::AggregationJobId::try_from_base64url(
resource_id: AggregationJobId::try_from_base64url(
url.path().split('/').last().unwrap(),
)
.unwrap(),
@@ -986,7 +986,7 @@
&**self.peer.as_ref().expect("peer not configured"),
DapRequest {
payload: re_encode(&meta, payload),
resource_id: resource::None,
resource_id: (),
meta,
},
)
@@ -1010,7 +1010,7 @@
&**self.peer.as_ref().expect("peer not configured"),
DapRequest {
payload: re_encode(&meta, payload),
resource_id: resource::AggregationJobId::try_from_base64url(
resource_id: AggregationJobId::try_from_base64url(
url.path().split('/').last().unwrap(),
)
.unwrap(),
