From eca027aa51f82979f8191dedaf48cd07e7adf9be Mon Sep 17 00:00:00 2001 From: Dsaquel Date: Fri, 23 Feb 2024 16:43:04 +0100 Subject: [PATCH] add same custom http client for jwks refresh --- jwt-authorizer/src/authorizer.rs | 6 +++--- jwt-authorizer/src/jwks/key_store_manager.rs | 22 +++++++++++++------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/jwt-authorizer/src/authorizer.rs b/jwt-authorizer/src/authorizer.rs index d80963c..81f0462 100644 --- a/jwt-authorizer/src/authorizer.rs +++ b/jwt-authorizer/src/authorizer.rs @@ -201,7 +201,7 @@ where } KeySourceType::Jwks(url) => { let jwks_url = Url::parse(url.as_str()).map_err(|e| InitError::JwksUrlError(e.to_string()))?; - let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); + let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default(), http_client); Authorizer { key_source: KeySource::KeyStoreSource(key_store_manager), claims_checker, @@ -210,10 +210,10 @@ where } } KeySourceType::Discovery(issuer_url) => { - let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), http_client).await?) + let jwks_url = Url::parse(&oidc::discover_jwks(issuer_url.as_str(), http_client.clone()).await?) .map_err(|e| InitError::JwksUrlError(e.to_string()))?; - let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default()); + let key_store_manager = KeyStoreManager::new(jwks_url, refresh.unwrap_or_default(), http_client); Authorizer { key_source: KeySource::KeyStoreSource(key_store_manager), claims_checker, diff --git a/jwt-authorizer/src/jwks/key_store_manager.rs b/jwt-authorizer/src/jwks/key_store_manager.rs index 2ff3edf..c04681a 100644 --- a/jwt-authorizer/src/jwks/key_store_manager.rs +++ b/jwt-authorizer/src/jwks/key_store_manager.rs @@ -1,5 +1,5 @@ use jsonwebtoken::{jwk::JwkSet, Algorithm}; -use reqwest::Url; +use reqwest::{Client, Url}; use std::{ sync::Arc, time::{Duration, Instant}, @@ -55,6 +55,7 @@ pub struct KeyStoreManager { /// in case of fail loading (error or key not found), minimal interval refresh: Refresh, keystore: Arc>, + client: Option, } pub struct KeyStore { @@ -67,7 +68,7 @@ pub struct KeyStore { } impl KeyStoreManager { - pub(crate) fn new(key_url: Url, refresh: Refresh) -> KeyStoreManager { + pub(crate) fn new(key_url: Url, refresh: Refresh, client: Option) -> KeyStoreManager { KeyStoreManager { key_url, refresh, @@ -76,6 +77,7 @@ impl KeyStoreManager { load_time: None, fail_time: None, })), + client, } } @@ -85,7 +87,7 @@ impl KeyStoreManager { let key = match self.refresh.strategy { RefreshStrategy::Interval => { if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) { - ks_gard.refresh(&self.key_url, &[]).await?; + ks_gard.refresh(&self.key_url, &[], self.client.as_ref()).await?; } ks_gard.get_key(header)? } @@ -95,7 +97,7 @@ impl KeyStoreManager { if let Some(jwk) = jwk_opt { jwk } else if ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) { - ks_gard.refresh(&self.key_url, &[("kid", kid)]).await?; + ks_gard.refresh(&self.key_url, &[("kid", kid)], self.client.as_ref()).await?; ks_gard.find_kid(kid).ok_or_else(|| AuthError::InvalidKid(kid.to_owned()))? } else { return Err(AuthError::InvalidKid(kid.to_owned())); @@ -112,6 +114,7 @@ impl KeyStoreManager { "alg", &serde_json::to_string(&header.alg).map_err(|_| AuthError::InvalidKeyAlg(header.alg))?, )], + self.client.as_ref(), ) .await?; ks_gard @@ -127,7 +130,7 @@ impl KeyStoreManager { // if jwks endpoint is down for the loading, respect retry_interval && ks_gard.can_refresh(self.refresh.refresh_interval, self.refresh.retry_interval) { - ks_gard.refresh(&self.key_url, &[]).await?; + ks_gard.refresh(&self.key_url, &[], self.client.as_ref()).await?; } ks_gard.get_key(header)? } @@ -151,8 +154,10 @@ impl KeyStore { } } - async fn refresh(&mut self, key_url: &Url, qparam: &[(&str, &str)]) -> Result<(), AuthError> { - reqwest::Client::new() + async fn refresh(&mut self, key_url: &Url, qparam: &[(&str, &str)], client: Option<&Client>) -> Result<(), AuthError> { + client + .cloned() + .unwrap_or_default() .get(key_url.as_ref()) .query(qparam) .send() @@ -372,6 +377,7 @@ mod tests { refresh_interval: Duration::from_millis(10), retry_interval: Duration::from_millis(5), }, + None, ); // 1st RELOAD @@ -419,6 +425,7 @@ mod tests { refresh_interval: Duration::from_millis(10), retry_interval: Duration::from_millis(5), }, + None, ); // STEP 1: initial (lazy) reloading @@ -477,6 +484,7 @@ mod tests { strategy: RefreshStrategy::NoRefresh, ..Default::default() }, + None, ); // STEP 1: initial (lazy) reloading