Skip to content

Commit

Permalink
Retry all requests in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mendess committed Jan 15, 2025
1 parent d64d14a commit 13833c5
Show file tree
Hide file tree
Showing 9 changed files with 360 additions and 259 deletions.
25 changes: 14 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

141 changes: 64 additions & 77 deletions crates/dapf/src/functions/helper.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
// Copyright (c) 2024 Cloudflare, Inc. All rights reserved.
// SPDX-License-Identifier: BSD-3-Clause

use anyhow::{anyhow, Context as _};
use anyhow::Context as _;
use daphne::{
constants::DapMediaType,
error::aborts::ProblemDetails,
messages::{
taskprov::TaskprovAdvertisement, AggregateShareReq, AggregationJobInitReq,
taskprov::TaskprovAdvertisement, AggregateShare, AggregateShareReq, AggregationJobInitReq,
AggregationJobResp,
},
DapVersion,
};
use daphne_service_utils::{bearer_token::BearerToken, http_headers};
use prio::codec::{ParameterizedDecode as _, ParameterizedEncode as _};
use prio::codec::ParameterizedEncode as _;
use reqwest::header;
use url::Url;

use crate::HttpClient;

use super::response_to_anyhow;
use super::retry_and_decode;

impl HttpClient {
pub async fn submit_aggregation_job_init_req(
Expand All @@ -28,45 +27,47 @@ impl HttpClient {
version: DapVersion,
opts: Options<'_>,
) -> anyhow::Result<AggregationJobResp> {
let resp = self
.put(url)
.body(agg_job_init_req.get_encoded_with_param(&version).unwrap())
.headers(construct_request_headers(
DapMediaType::AggregationJobInitReq
.as_str_for_version(version)
.with_context(|| {
format!("AggregationJobInitReq media type is not defined for {version}")
})?,
version,
opts,
)?)
.send()
.await
.context("sending AggregationJobInitReq")?;
if resp.status() == 400 {
let text = resp.text().await?;
let problem_details: ProblemDetails =
serde_json::from_str(&text).with_context(|| {
format!("400 Bad Request: failed to parse problem details document: {text:?}")
})?;
Err(anyhow!("400 Bad Request: {problem_details:?}"))
} else if resp.status() == 500 {
Err(anyhow::anyhow!(
"500 Internal Server Error: {}",
resp.text().await?
))
} else if !resp.status().is_success() {
Err(response_to_anyhow(resp).await).context("while running an AggregationJobInitReq")
} else {
AggregationJobResp::get_decoded_with_param(
&version,
&resp
.bytes()
.await
.context("transfering bytes from the AggregateInitReq")?,
)
.with_context(|| "failed to parse response to AggregateInitReq from Helper")
}
retry_and_decode(&version, || async {
self.put(url.clone())
.body(agg_job_init_req.get_encoded_with_param(&version).unwrap())
.headers(construct_request_headers(
DapMediaType::AggregationJobInitReq
.as_str_for_version(version)
.with_context(|| {
format!("AggregationJobInitReq media type is not defined for {version}")
})?,
version,
opts,
)?)
.send()
.await
.context("sending AggregationJobInitReq")
})
.await
}

pub async fn poll_aggregation_job_init(
&self,
url: Url,
version: DapVersion,
opts: Options<'_>,
) -> anyhow::Result<AggregationJobResp> {
retry_and_decode(&version, || async {
self.get(url.clone())
.headers(construct_request_headers(
DapMediaType::AggregationJobInitReq
.as_str_for_version(version)
.with_context(|| {
format!("AggregationJobInitReq media type is not defined for {version}")
})?,
version,
opts,
)?)
.send()
.await
.context("polling aggregation job init req")
})
.await
}

pub async fn get_aggregate_share(
Expand All @@ -75,42 +76,28 @@ impl HttpClient {
agg_share_req: AggregateShareReq,
version: DapVersion,
opts: Options<'_>,
) -> anyhow::Result<()> {
let resp = self
.post(url)
.body(agg_share_req.get_encoded_with_param(&version).unwrap())
.headers(construct_request_headers(
DapMediaType::AggregateShareReq
.as_str_for_version(version)
.with_context(|| {
format!("AggregateShareReq media type is not defined for {version}")
})?,
version,
opts,
)?)
.send()
.await
.context("sending AggregateShareReq")?;
if resp.status() == 400 {
let problem_details: ProblemDetails = serde_json::from_slice(
&resp
.bytes()
.await
.context("transfering bytes for AggregateShareReq")?,
)
.with_context(|| "400 Bad Request: failed to parse problem details document")?;
Err(anyhow!("400 Bad Request: {problem_details:?}"))
} else if resp.status() == 500 {
Err(anyhow!("500 Internal Server Error: {}", resp.text().await?))
} else if !resp.status().is_success() {
Err(response_to_anyhow(resp).await).context("while running an AggregateShareReq")
} else {
Ok(())
}
) -> anyhow::Result<AggregateShare> {
retry_and_decode(&(), || async {
self.post(url.clone())
.body(agg_share_req.get_encoded_with_param(&version).unwrap())
.headers(construct_request_headers(
DapMediaType::AggregateShareReq
.as_str_for_version(version)
.with_context(|| {
format!("AggregateShareReq media type is not defined for {version}")
})?,
version,
opts,
)?)
.send()
.await
.context("sending AggregateShareReq")
})
.await
}
}

#[derive(Default, Debug)]
#[derive(Default, Debug, Clone, Copy)]
pub struct Options<'s> {
pub taskprov_advertisement: Option<&'s TaskprovAdvertisement>,
pub bearer_token: Option<&'s BearerToken>,
Expand Down
66 changes: 35 additions & 31 deletions crates/dapf/src/functions/hpke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use x509_parser::pem::Pem;

use crate::HttpClient;

use super::response_to_anyhow;
use super::retry;

impl HttpClient {
pub async fn get_hpke_config(
Expand All @@ -22,36 +22,40 @@ impl HttpClient {
certificate_file: Option<&Path>,
) -> anyhow::Result<HpkeConfigList> {
let url = base_url.join("hpke_config")?;
let resp = self
.get(url.as_str())
.send()
.await
.with_context(|| "request failed")?;
if !resp.status().is_success() {
return Err(response_to_anyhow(resp).await);
}
let maybe_signature = resp.headers().get(http_headers::HPKE_SIGNATURE).cloned();
let hpke_config_bytes = resp.bytes().await.context("failed to read hpke config")?;
if let Some(cert_path) = certificate_file {
let cert = std::fs::read_to_string(cert_path).context("reading the certificate")?;
let Some(signature) = maybe_signature else {
anyhow::bail!("Aggregator did not sign its response");
};
let signature_bytes =
decode_base64url_vec(signature.as_bytes()).context("decoding the signature")?;
let (cert_pem, _bytes_read) =
Pem::read(Cursor::new(cert.as_bytes())).context("reading PEM certificate")?;
let cert = EndEntityCert::try_from(cert_pem.contents.as_ref())
.map_err(|e| anyhow!("{e:?}")) // webpki::Error does not implement std::error::Error
.context("parsing PEM certificate")?;
retry(
|| async {
self.get(url.as_str())
.send()
.await
.with_context(|| "request failed")
},
|resp| async {
let maybe_signature = resp.headers().get(http_headers::HPKE_SIGNATURE).cloned();
let hpke_config_bytes = resp.bytes().await.context("failed to read hpke config")?;
if let Some(cert_path) = certificate_file {
let cert =
std::fs::read_to_string(cert_path).context("reading the certificate")?;
let Some(signature) = maybe_signature else {
anyhow::bail!("Aggregator did not sign its response");
};
let signature_bytes = decode_base64url_vec(signature.as_bytes())
.context("decoding the signature")?;
let (cert_pem, _bytes_read) = Pem::read(Cursor::new(cert.as_bytes()))
.context("reading PEM certificate")?;
let cert = EndEntityCert::try_from(cert_pem.contents.as_ref())
.map_err(|e| anyhow!("{e:?}")) // webpki::Error does not implement std::error::Error
.context("parsing PEM certificate")?;

cert.verify_signature(
&ECDSA_P256_SHA256,
&hpke_config_bytes,
signature_bytes.as_ref(),
)
.map_err(|e| anyhow!("signature not verified: {}", e.to_string()))?;
}
Ok(HpkeConfigList::get_decoded(&hpke_config_bytes)?)
cert.verify_signature(
&ECDSA_P256_SHA256,
&hpke_config_bytes,
signature_bytes.as_ref(),
)
.map_err(|e| anyhow!("signature not verified: {}", e.to_string()))?;
}
Ok(HpkeConfigList::get_decoded(&hpke_config_bytes)?)
},
)
.await
}
}
Loading

0 comments on commit 13833c5

Please sign in to comment.