diff --git a/nym-vpn-core/crates/nym-vpn-api-client/src/client.rs b/nym-vpn-core/crates/nym-vpn-api-client/src/client.rs index 4d8fd4bb37..9014eaf7d9 100644 --- a/nym-vpn-core/crates/nym-vpn-api-client/src/client.rs +++ b/nym-vpn-core/crates/nym-vpn-api-client/src/client.rs @@ -1,7 +1,10 @@ // Copyright 2024 - Nym Technologies SA // SPDX-License-Identifier: GPL-3.0-only -use std::time::Duration; +use std::{ + sync::Arc, + time::{Duration, Instant}, +}; use backon::Retryable; use nym_credential_proxy_requests::api::v1::ticketbook::models::PartialVerificationKeysResponse; @@ -9,7 +12,8 @@ use nym_http_api_client::{ ApiClient, Client, HttpClientError, NO_PARAMS, Params, PathSegments, Url, UserAgent, }; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use time::OffsetDateTime; +use time::{Duration as TimeDuration, OffsetDateTime}; +use tokio::sync::RwLock; use crate::{ ResolverOverrides, api_urls_to_urls, @@ -40,11 +44,50 @@ pub(crate) const DEVICE_AUTHORIZATION_HEADER: &str = "x-device-authorization"; // GET requests can unfortunately take a long time over the mixnet pub(crate) const NYM_VPN_API_TIMEOUT: Duration = Duration::from_secs(60); +const SKEW_CACHE_TTL: Duration = Duration::from_secs(4 * 60 * 60); // 4 hours + +#[derive(Debug)] +struct SkewState { + skew: TimeDuration, + expires_at: Instant, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum SkewStatus { + Expired, + Valid(TimeDuration), +} + +impl SkewState { + fn new(skew: TimeDuration, now: Instant) -> Self { + Self { + skew, + expires_at: now + SKEW_CACHE_TTL, + } + } + + fn update(&mut self, skew: TimeDuration, now: Instant) { + self.skew = skew; + self.expires_at = now + SKEW_CACHE_TTL; + } + + fn status(&self, now: Instant) -> SkewStatus { + if self.expires_at > now { + SkewStatus::Valid(self.skew) + } else { + SkewStatus::Expired + } + } +} + #[derive(Clone, Debug)] pub struct VpnApiClient { inner: Client, urls: Vec, user_agent: UserAgent, + skew_state: Arc>>, + #[cfg(test)] + mock_remote_time: Arc>>, } impl VpnApiClient { @@ -65,6 +108,9 @@ impl VpnApiClient { inner, urls, user_agent, + skew_state: Arc::new(RwLock::new(None)), + #[cfg(test)] + mock_remote_time: Arc::new(RwLock::new(None)), }) } @@ -96,6 +142,9 @@ impl VpnApiClient { inner, urls, user_agent, + skew_state: Arc::new(RwLock::new(None)), + #[cfg(test)] + mock_remote_time: Arc::new(RwLock::new(None)), }) } @@ -123,6 +172,11 @@ impl VpnApiClient { } pub async fn get_remote_time(&self) -> Result { + #[cfg(test)] + if let Some(mocked) = self.mock_remote_time.read().await.clone() { + return Ok(mocked); + } + let time_before = OffsetDateTime::now_utc(); let remote_timestamp = self.get_health().await?.timestamp_utc; let time_after = OffsetDateTime::now_utc(); @@ -153,8 +207,55 @@ impl VpnApiClient { } } - async fn sync_with_remote_time(&self) -> Result> { + async fn refresh_skew(&self) -> Result { let remote_time = self.get_remote_time().await?; + let skew = remote_time.local_time_ahead_skew(); + let now = Instant::now(); + + { + let mut skew_state = self.skew_state.write().await; + match skew_state.as_mut() { + Some(state) => state.update(skew, now), + None => *skew_state = Some(SkewState::new(skew, now)), + } + } + + tracing::debug!(skew = ?skew, "Refreshed VPN API time skew"); + + Ok(remote_time) + } + + async fn current_remote_time(&self) -> Result> { + let now = Instant::now(); + let status = { + let state = self.skew_state.read().await; + state.as_ref().map(|state| state.status(now)) + }; + + let cached_remote_time = match status { + Some(SkewStatus::Valid(skew)) => { + tracing::debug!("Valid VPN API time skew"); + let local_time = OffsetDateTime::now_utc(); + let estimated_remote_time = local_time - skew; + + VpnApiTime::from_estimated_remote_time(local_time, estimated_remote_time) + } + Some(SkewStatus::Expired) | None => { + tracing::debug!("VPN API time skew expired or not present, refreshing"); + + self.refresh_skew().await? + } + }; + + Ok(if Self::use_remote_time(cached_remote_time) { + Some(cached_remote_time) + } else { + None + }) + } + + async fn sync_with_remote_time(&self) -> Result> { + let remote_time = self.refresh_skew().await?; if Self::use_remote_time(remote_time) { Ok(Some(remote_time)) @@ -198,7 +299,18 @@ impl VpnApiClient { where T: DeserializeOwned, { - match self.get_query::(path, account, device, None).await { + let jwt = match self.current_remote_time().await { + Ok(remote_time) => remote_time, + Err(err) => { + tracing::debug!( + error = %err, + "Failed to determine cached remote time" + ); + None + } + }; + + match self.get_query::(path, account, device, jwt).await { Ok(response) => Ok(response), Err(err) => { if let HttpClientError::EndpointFailure { error, .. } = &err @@ -364,8 +476,19 @@ impl VpnApiClient { T: DeserializeOwned, B: Serialize, { + let jwt = match self.current_remote_time().await { + Ok(remote_time) => remote_time, + Err(err) => { + tracing::debug!( + error = %err, + "Failed to determine cached remote time" + ); + None + } + }; + match self - .post_query::(path, json_body, account, device, None) + .post_query::(path, json_body, account, device, jwt) .await { Ok(response) => Ok(response), @@ -1281,8 +1404,170 @@ impl VpnApiClient { .map_err(Box::new) .map_err(VpnApiClientError::GetVpnNetworkDetails) } + + // TEST HELPERS + #[cfg(test)] + pub(super) async fn set_mock_remote_time(&self, remote_time: Option) { + let mut guard = self.mock_remote_time.write().await; + *guard = remote_time; + } } fn jwt_error(error: &str) -> bool { error.to_lowercase().contains("jwt") } + +// skew_state tests +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::time::{Duration as StdDuration, Instant}; + + use super::*; + + fn test_user_agent() -> UserAgent { + UserAgent { + application: "vpn-api-client-test".to_string(), + version: "0.0.1".to_string(), + platform: "test-platform".to_string(), + git_commit: "test-commit".to_string(), + } + } + + fn test_client() -> VpnApiClient { + let base_url = "http://localhost"; + let inner = Client::new_url(base_url, Some(StdDuration::from_secs(1))).unwrap(); + let parsed_url = Url::parse(base_url).unwrap(); + + VpnApiClient { + inner, + urls: vec![parsed_url], + user_agent: test_user_agent(), + skew_state: Arc::new(RwLock::new(None)), + mock_remote_time: Arc::new(RwLock::new(None)), + } + } + + fn remote_time_with_skew(seconds: i64) -> VpnApiTime { + let local_time = OffsetDateTime::now_utc(); + let estimated_remote_time = local_time - TimeDuration::seconds(seconds); + VpnApiTime::from_estimated_remote_time(local_time, estimated_remote_time) + } + + #[tokio::test] + async fn current_remote_time_returns_cached_for_valid_skew() { + let client = test_client(); + + { + let mut state = client.skew_state.write().await; + *state = Some(SkewState { + skew: TimeDuration::seconds(120), + expires_at: Instant::now() + StdDuration::from_secs(60), + }); + } + + { + let state = client.skew_state.read().await; + assert!(matches!( + state.as_ref().unwrap().status(Instant::now()), + SkewStatus::Valid(_) + )); + } + + let remote_time = client.current_remote_time().await.unwrap(); + assert!(remote_time.is_some()); + + let state = client.skew_state.read().await; + assert!(matches!( + state.as_ref().unwrap().status(Instant::now()), + SkewStatus::Valid(_) + )); + } + + #[tokio::test] + async fn current_remote_time_returns_none_for_synced_skew() { + let client = test_client(); + + { + let mut state = client.skew_state.write().await; + *state = Some(SkewState { + skew: TimeDuration::seconds(1), + expires_at: Instant::now() + StdDuration::from_secs(60), + }); + } + + { + let state = client.skew_state.read().await; + assert!(matches!( + state.as_ref().unwrap().status(Instant::now()), + SkewStatus::Valid(_) + )); + } + + let remote_time = client.current_remote_time().await.unwrap(); + assert!(remote_time.is_none()); + } + + #[tokio::test] + async fn current_remote_time_refreshes_expired_skew() { + let client = test_client(); + + { + let mut state = client.skew_state.write().await; + *state = Some(SkewState { + skew: TimeDuration::seconds(10), + expires_at: Instant::now() - StdDuration::from_secs(1), + }); + } + + let mocked_remote_time = remote_time_with_skew(180); + client.set_mock_remote_time(Some(mocked_remote_time)).await; + + { + let state = client.skew_state.read().await; + assert!(matches!( + state.as_ref().unwrap().status(Instant::now()), + SkewStatus::Expired + )); + } + + let remote_time = client.current_remote_time().await.unwrap(); + assert!(remote_time.is_some()); + + let state = client.skew_state.read().await; + assert!(matches!( + state.as_ref().unwrap().status(Instant::now()), + SkewStatus::Valid(_) + )); + assert_eq!( + remote_time + .unwrap() + .local_time_ahead_skew() + .whole_seconds() + .abs(), + 180 + ); + } + + #[tokio::test] + async fn current_remote_time_refreshes_when_missing() { + let client = test_client(); + + let mocked_remote_time = remote_time_with_skew(200); + client.set_mock_remote_time(Some(mocked_remote_time)).await; + + assert!(client.skew_state.read().await.is_none()); + let remote_time = client.current_remote_time().await.unwrap(); + assert!(remote_time.is_some()); + + assert!(client.skew_state.read().await.is_some()); + assert_eq!( + remote_time + .unwrap() + .local_time_ahead_skew() + .whole_seconds() + .abs(), + 200 + ); + } +}