Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make more use of DapRole enums #728

Merged
merged 4 commits into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/dapf/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use dapf::{
HttpClient,
};
use daphne::{
constants::DapAggregatorRole,
hpke::HpkeReceiverConfig,
messages::{
self, encode_base64url, BatchSelector, CollectionReq, PartialBatchSelector, Query, TaskId,
Expand All @@ -23,7 +24,6 @@ use daphne::{
use daphne_service_utils::{
bearer_token::BearerToken,
test_route_types::{InternalTestAddTask, InternalTestVdaf},
DapRole,
};
use prio::codec::{ParameterizedDecode, ParameterizedEncode};
use rand::{thread_rng, Rng};
Expand Down Expand Up @@ -236,7 +236,7 @@ enum TestAction {
#[arg(long)]
collector_auth_token: Option<String>,
#[arg(long)]
role: Option<DapRole>,
role: Option<DapAggregatorRole>,
#[arg(long)]
query: Option<CliDapQueryConfig>,
#[arg(long)]
Expand Down Expand Up @@ -866,11 +866,11 @@ async fn handle_test_routes(action: TestAction, http_client: HttpClient) -> anyh
"leader auth token",
)?,
collector_authentication_token: match role {
DapRole::Leader => Some(use_or_request_from_user(
DapAggregatorRole::Leader => Some(use_or_request_from_user(
collector_auth_token,
"collector auth token",
)?),
DapRole::Helper => None,
DapAggregatorRole::Helper => None,
},
role,
vdaf_verify_key,
Expand Down
8 changes: 4 additions & 4 deletions crates/daphne-server/examples/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
use std::path::PathBuf;

use clap::Parser;
use daphne::constants::DapAggregatorRole;
use daphne_server::{
config::DaphneServiceConfig, metrics::DaphnePromServiceMetrics, router, App, StorageProxyConfig,
};
use daphne_service_utils::DapRole;
use serde::{Deserialize, Serialize};
use tokio::net::TcpListener;
use tracing_subscriber::EnvFilter;
Expand Down Expand Up @@ -47,8 +47,8 @@ impl TryFrom<Args> for Config {
config::Value::new(
Some(&String::from("args.role")),
match role {
DapRole::Leader => "leader",
DapRole::Helper => "helper",
DapAggregatorRole::Leader => "leader",
DapAggregatorRole::Helper => "helper",
},
)
}),
Expand Down Expand Up @@ -81,7 +81,7 @@ struct Args {
// --- command line overridable parameters ---
/// One of `leader` or `helper`.
#[arg(short, long)]
role: Option<DapRole>,
role: Option<DapAggregatorRole>,
/// The port to listen on.
#[arg(short, long)]
port: Option<u16>,
Expand Down
5 changes: 3 additions & 2 deletions crates/daphne-server/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
// SPDX-License-Identifier: BSD-3-Clause

use daphne::{
constants::DapAggregatorRole,
hpke::{HpkeConfig, HpkeReceiverConfig},
DapGlobalConfig, DapVersion,
};
use daphne_service_utils::{bearer_token::BearerToken, DapRole};
use daphne_service_utils::bearer_token::BearerToken;
use p256::ecdsa::SigningKey;
use serde::{Deserialize, Serialize};
use url::Url;
Expand Down Expand Up @@ -45,7 +46,7 @@ pub type HpkeRecieverConfigList = Vec<HpkeReceiverConfig>;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct DaphneServiceConfig {
/// Indicates the role the service should play.
pub role: DapRole,
pub role: DapAggregatorRole,

/// Global DAP configuration.
#[serde(flatten)]
Expand Down
20 changes: 10 additions & 10 deletions crates/daphne-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@ use std::sync::Arc;
use config::{DaphneServiceConfig, PeerBearerToken};
use daphne::{
audit_log::{AuditLog, NoopAuditLog},
constants::DapRole,
fatal_error,
messages::{Base64Encode, TaskId},
roles::{leader::in_memory_leader::InMemoryLeaderState, DapAggregator},
DapError, DapSender,
DapError,
};
use daphne_service_utils::bearer_token::BearerToken;
use either::Either::{self, Left, Right};
Expand Down Expand Up @@ -42,15 +43,14 @@ mod storage_proxy_connection;
/// ```
/// use std::num::NonZeroUsize;
/// use url::Url;
/// use daphne::{DapGlobalConfig, hpke::HpkeKemId, DapVersion};
/// use daphne::{DapGlobalConfig, constants::DapAggregatorRole, hpke::HpkeKemId, DapVersion};
/// use daphne_server::{
/// App,
/// router,
/// StorageProxyConfig,
/// metrics::DaphnePromServiceMetrics,
/// config::DaphneServiceConfig,
/// };
/// use daphne_service_utils::DapRole;
///
/// let storage_proxy_settings = StorageProxyConfig {
/// url: Url::parse("http://example.com").unwrap(),
Expand All @@ -66,7 +66,7 @@ mod storage_proxy_connection;
/// default_num_agg_span_shards: NonZeroUsize::new(2).unwrap(),
/// };
/// let service_config = DaphneServiceConfig {
/// role: DapRole::Helper,
/// role: DapAggregatorRole::Helper,
/// global,
/// base_url: None,
/// taskprov: None,
Expand All @@ -77,7 +77,7 @@ mod storage_proxy_connection;
/// };
/// let app = App::new(storage_proxy_settings, daphne_service_metrics, service_config)?;
///
/// let router = router::new(DapRole::Helper, app);
/// let router = router::new(DapAggregatorRole::Helper, app);
///
/// # Ok::<(), daphne::DapError>(())
/// ```
Expand Down Expand Up @@ -114,7 +114,7 @@ impl router::DaphneService for App {
async fn check_bearer_token(
&self,
presented_token: &BearerToken,
sender: DapSender,
sender: DapRole,
task_id: TaskId,
is_taskprov: bool,
) -> Result<(), Either<String, DapError>> {
Expand All @@ -131,16 +131,16 @@ impl router::DaphneService for App {
.filter(|_| self.service_config.taskprov.is_some() && is_taskprov)
{
match (&taskprov.peer_auth, sender) {
(PeerBearerToken::Leader { expected_token }, DapSender::Leader)
| (PeerBearerToken::Collector { expected_token }, DapSender::Collector)
(PeerBearerToken::Leader { expected_token }, DapRole::Leader)
| (PeerBearerToken::Collector { expected_token }, DapRole::Collector)
if expected_token == presented_token =>
{
Ok(())
}
(PeerBearerToken::Leader { .. }, DapSender::Collector) => Err(Right(fatal_error!(
(PeerBearerToken::Leader { .. }, DapRole::Collector) => Err(Right(fatal_error!(
err = "expected a leader sender but got a collector sender"
))),
(PeerBearerToken::Collector { .. }, DapSender::Leader) => Err(Right(fatal_error!(
(PeerBearerToken::Collector { .. }, DapRole::Leader) => Err(Right(fatal_error!(
err = "expected a collector sender but got a leader sender"
))),
_ => reject(format_args!("using taskprov")),
Expand Down
7 changes: 3 additions & 4 deletions crates/daphne-server/src/roles/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use daphne::{
audit_log::AuditLog,
error::DapAbort,
fatal_error,
hpke::{self, HpkeConfig, HpkeProvider, HpkeReceiverConfig},
hpke::{self, info_and_aad, HpkeConfig, HpkeProvider, HpkeReceiverConfig},
messages::{self, BatchId, BatchSelector, HpkeCiphertext, TaskId, Time},
metrics::DaphneMetrics,
roles::{
Expand Down Expand Up @@ -363,11 +363,10 @@ pub struct HpkeDecrypter(Marc<Vec<HpkeReceiverConfig>>);
impl hpke::HpkeDecrypter for HpkeDecrypter {
fn hpke_decrypt(
&self,
info: &[u8],
aad: &[u8],
info: impl info_and_aad::InfoAndAad,
ciphertext: &HpkeCiphertext,
) -> Result<Vec<u8>, DapError> {
self.0.hpke_decrypt(info, aad, ciphertext)
self.0.hpke_decrypt(info, ciphertext)
}
}

Expand Down
4 changes: 2 additions & 2 deletions crates/daphne-server/src/roles/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{borrow::Cow, time::Instant};

use axum::{async_trait, http::Method};
use daphne::{
constants::DapMediaType,
constants::{DapMediaType, DapRole},
error::DapAbort,
fatal_error,
messages::{BatchId, BatchSelector, Collection, CollectionJobId, Report, TaskId},
Expand Down Expand Up @@ -171,7 +171,7 @@ impl crate::App {
}
} else if let Some(bearer_token) = self
.bearer_tokens()
.get(daphne::DapSender::Leader, meta.task_id)
.get(DapRole::Leader, meta.task_id)
.await
.map_err(|e| fatal_error!(err = ?e, "failed to get leader bearer token"))?
{
Expand Down
30 changes: 15 additions & 15 deletions crates/daphne-server/src/roles/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use daphne::{messages::TaskId, DapSender, ReplayProtection};
use daphne::{constants::DapRole, messages::TaskId, ReplayProtection};
use daphne_service_utils::bearer_token::BearerToken;

use crate::storage_proxy_connection::{
Expand Down Expand Up @@ -54,12 +54,12 @@ impl BearerTokens<'_> {
#[cfg(feature = "test-utils")]
pub async fn put_if_not_exists(
&self,
role: DapSender,
sender: DapRole,
task_id: TaskId,
token: BearerToken,
) -> Result<Option<BearerToken>, storage_proxy_connection::Error> {
self.kv
.put_if_not_exists::<kv::prefix::KvBearerToken>(&(role, task_id).into(), token)
.put_if_not_exists::<kv::prefix::KvBearerToken>(&(sender, task_id).into(), token)
.await
}

Expand All @@ -72,13 +72,13 @@ impl BearerTokens<'_> {
/// - `Err(error)` if any io error occurs while fetching
pub async fn matches(
&self,
role: DapSender,
sender: DapRole,
task_id: TaskId,
token: &BearerToken,
) -> Result<bool, Marc<storage_proxy_connection::Error>> {
self.kv
.peek::<kv::prefix::KvBearerToken, _, _>(
&(role, task_id).into(),
&(sender, task_id).into(),
&kv::KvGetOptions {
cache_not_found: false,
},
Expand All @@ -90,12 +90,12 @@ impl BearerTokens<'_> {

pub async fn get(
&self,
role: DapSender,
sender: DapRole,
task_id: TaskId,
) -> Result<Option<BearerToken>, Marc<storage_proxy_connection::Error>> {
self.kv
.get_cloned::<kv::prefix::KvBearerToken>(
&(role, task_id).into(),
&(sender, task_id).into(),
&kv::KvGetOptions {
cache_not_found: false,
},
Expand All @@ -107,17 +107,17 @@ impl BearerTokens<'_> {
#[cfg(feature = "test-utils")]
mod test_utils {
use daphne::{
constants::{DapAggregatorRole, DapRole},
fatal_error,
hpke::{HpkeConfig, HpkeReceiverConfig},
messages::decode_base64url_vec,
roles::DapAggregator,
vdaf::{Prio3Config, VdafConfig},
DapError, DapQueryConfig, DapSender, DapTaskConfig, DapVersion,
DapError, DapQueryConfig, DapTaskConfig, DapVersion,
};
use daphne_service_utils::{
bearer_token::BearerToken,
test_route_types::{InternalTestAddTask, InternalTestEndpointForTask},
DapRole,
};
use prio::codec::Decode;
use std::num::NonZeroUsize;
Expand Down Expand Up @@ -236,7 +236,7 @@ mod test_utils {
let token = BearerToken::from(cmd.leader_authentication_token);
if self
.bearer_tokens()
.put_if_not_exists(DapSender::Leader, cmd.task_id, token)
.put_if_not_exists(DapRole::Leader, cmd.task_id, token)
.await
.is_err()
{
Expand All @@ -248,11 +248,11 @@ mod test_utils {

// Collector authentication token.
match (cmd.role, cmd.collector_authentication_token) {
(DapRole::Leader, Some(token_string)) => {
(DapAggregatorRole::Leader, Some(token_string)) => {
let token = BearerToken::from(token_string);
if self
.bearer_tokens()
.put_if_not_exists(DapSender::Collector, cmd.task_id, token)
.put_if_not_exists(DapRole::Collector, cmd.task_id, token)
.await
.is_err()
{
Expand All @@ -262,13 +262,13 @@ mod test_utils {
)));
}
}
(DapRole::Leader, None) => {
(DapAggregatorRole::Leader, None) => {
return Err(fatal_error!(
err = "command failed: missing collector authentication token",
))
}
(DapRole::Helper, None) => (),
(DapRole::Helper, Some(..)) => {
(DapAggregatorRole::Helper, None) => (),
(DapAggregatorRole::Helper, Some(..)) => {
return Err(fatal_error!(
err = "command failed: unexpected collector authentication token",
));
Expand Down
16 changes: 8 additions & 8 deletions crates/daphne-server/src/router/extractor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,12 @@ pub mod dap_sender {
pub const FROM_HELPER: DapSender = 1 << 2;
pub const FROM_LEADER: DapSender = 1 << 3;

pub const fn to_enum(id: DapSender) -> daphne::DapSender {
pub const fn to_enum(id: DapSender) -> daphne::constants::DapRole {
match id {
FROM_CLIENT => daphne::DapSender::Client,
FROM_COLLECTOR => daphne::DapSender::Collector,
FROM_HELPER => daphne::DapSender::Helper,
FROM_LEADER => daphne::DapSender::Leader,
FROM_CLIENT => daphne::constants::DapRole::Client,
FROM_COLLECTOR => daphne::constants::DapRole::Collector,
FROM_HELPER => daphne::constants::DapRole::Helper,
FROM_LEADER => daphne::constants::DapRole::Leader,
_ => panic!("invalid dap sender. Please specify a valid dap_sender from the crate::extractor::dap_sender module"),
}
}
Expand Down Expand Up @@ -344,14 +344,14 @@ mod test {
};
use daphne::{
async_test_versions,
constants::DapMediaType,
constants::{DapMediaType, DapRole},
messages::{
request::{CollectionPollReq, RequestBody},
taskprov::TaskprovAdvertisement,
AggregationJobId, AggregationJobInitReq, Base64Encode, CollectionJobId, CollectionReq,
TaskId,
},
DapError, DapRequest, DapRequestMeta, DapSender, DapVersion,
DapError, DapRequest, DapRequestMeta, DapVersion,
};
use daphne_service_utils::{bearer_token::BearerToken, http_headers};
use either::Either::{self, Left};
Expand Down Expand Up @@ -409,7 +409,7 @@ mod test {
async fn check_bearer_token(
&self,
token: &BearerToken,
_sender: DapSender,
_sender: DapRole,
_task_id: TaskId,
_is_taskprov: bool,
) -> Result<(), Either<String, DapError>> {
Expand Down
Loading
Loading