Skip to content

Commit

Permalink
generalize global values
Browse files Browse the repository at this point in the history
  • Loading branch information
t-aleksander committed Jan 15, 2025
1 parent 97cb8c6 commit 104b668
Show file tree
Hide file tree
Showing 7 changed files with 84 additions and 111 deletions.
29 changes: 6 additions & 23 deletions src/db/models/settings.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,23 @@
use std::{
collections::HashMap,
str::FromStr,
sync::{RwLock, RwLockReadGuard},
};
use std::{collections::HashMap, str::FromStr};

use sqlx::{query, query_as, PgExecutor, PgPool, Type};
use struct_patch::Patch;
use thiserror::Error;

use crate::secret::SecretString;
use crate::{global_value, secret::SecretString};

// wrap in `Option` since a static cannot be initialized with a non-const function
static SETTINGS: RwLock<Option<Settings>> = RwLock::new(None);

pub(crate) fn set_settings(new_settings: Settings) {
*SETTINGS
.write()
.expect("Failed to acquire lock on current settings.") = Some(new_settings);
}

pub(crate) fn get_settings() -> RwLockReadGuard<'static, Option<Settings>> {
SETTINGS
.read()
.expect("Failed to acquire lock on current settings.")
}
global_value!(SETTINGS, Option<Settings>, None, set_settings, get_settings);

/// Initializes global `SETTINGS` struct at program startup
pub async fn initialize_current_settings(pool: &PgPool) -> Result<(), sqlx::Error> {
debug!("Initializing global settings strut");
match Settings::get(pool).await? {
Some(settings) => {
set_settings(settings);
set_settings(Some(settings));
}
None => {
debug!("Settings not found in DB. Using default values to initialize global settings struct");
set_settings(Settings::default());
set_settings(Some(Settings::default()));
}
}
Ok(())
Expand All @@ -47,7 +30,7 @@ pub async fn update_current_settings(
) -> Result<(), sqlx::Error> {
debug!("Updating current settings to: {new_settings:?}");
new_settings.save(pool).await?;
set_settings(new_settings);
set_settings(Some(new_settings));
Ok(())
}

Expand Down
29 changes: 10 additions & 19 deletions src/enterprise/license.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
use std::{
sync::{RwLock, RwLockReadGuard},
time::Duration,
};
use std::time::Duration;

use anyhow::Result;
use base64::prelude::*;
Expand All @@ -15,26 +12,20 @@ use tokio::time::sleep;

use crate::{
db::{models::settings::update_current_settings, Settings},
server_config, VERSION,
global_value, server_config, VERSION,
};

use super::limits::Counts;

const LICENSE_SERVER_URL: &str = "https://pkgs.defguard.net/api/license/renew";

static LICENSE: RwLock<Option<License>> = RwLock::new(None);

pub fn set_cached_license(license: Option<License>) {
*LICENSE
.write()
.expect("Failed to acquire lock on the license mutex.") = license;
}

pub fn get_cached_license() -> RwLockReadGuard<'static, Option<License>> {
LICENSE
.read()
.expect("Failed to acquire lock on the license mutex.")
}
global_value!(
LICENSE,
Option<License>,
None,
set_cached_license,
get_cached_license
);

tonic::include_proto!("license");

Expand Down Expand Up @@ -584,7 +575,7 @@ pub async fn run_periodic_license_check(pool: &PgPool) -> Result<(), LicenseErro
let license = get_cached_license();
debug!("Checking if the license {license:?} requires a renewal...");

if let Some(license) = &*license {
if let Some(license) = license.as_ref() {
if license.requires_renewal() {
// check if we are pass the maximum expiration date, after which we don't
// want to try to renew the license anymore
Expand Down
61 changes: 12 additions & 49 deletions src/enterprise/limits.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::global_value;
use sqlx::{error::Error as SqlxError, query, PgPool};
use std::sync::{RwLock, RwLockReadGuard};

#[cfg(test)]
use super::license::get_cached_license;
Expand All @@ -11,60 +11,23 @@ pub const DEFAULT_DEVICES_LIMIT: u32 = 10;
pub const DEFAULT_LOCATIONS_LIMIT: u32 = 1;

#[derive(Debug, Default, Clone)]
pub(crate) struct Counts {
pub struct Counts {
user: u32,
device: u32,
wireguard_network: u32,
}

#[cfg(test)]
thread_local! {
static COUNTS: RwLock<Counts> = RwLock::new(Counts {
global_value!(
COUNTS,
Counts,
Counts {
user: 0,
device: 0,
wireguard_network: 0,
});
}

#[cfg(not(test))]
static COUNTS: RwLock<Counts> = RwLock::new(Counts {
user: 0,
device: 0,
wireguard_network: 0,
});

#[cfg(not(test))]
fn set_counts(new_counts: Counts) {
*COUNTS
.write()
.expect("Failed to acquire lock on the enterprise limit counts.") = new_counts;
}

#[cfg(not(test))]
pub(crate) fn get_counts() -> RwLockReadGuard<'static, Counts> {
COUNTS
.read()
.expect("Failed to acquire lock on the enterprise limit counts.")
}

#[cfg(test)]
fn set_counts(new_counts: Counts) {
COUNTS.with(|counts| {
*counts
.write()
.expect("Failed to acquire lock on the enterprise limit counts.") = new_counts;
});
}

#[cfg(test)]
pub(crate) fn get_counts() -> Counts {
COUNTS.with(|counts| {
counts
.read()
.expect("Failed to acquire lock on the enterprise limit counts.")
.clone()
})
}
wireguard_network: 0
},
set_counts,
get_counts
);

/// Update the counts of users, devices, and wireguard networks stored in the memory.
// TODO: Use it with database triggers when they are implemented
Expand Down Expand Up @@ -128,7 +91,7 @@ impl Counts {
let maybe_license = get_cached_license();

// validate limits against license if available, use defaults otherwise
match &*maybe_license {
match &maybe_license {
Some(license) => {
debug!("Cached license found. Validating license limits...");
self.is_over_license_limits(license)
Expand Down
49 changes: 49 additions & 0 deletions src/globals.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#[macro_export]
/// Specify a global value that can be accessed from anywhere in the application.
/// Pass the name of the global value, the type of the value, the initial value, the function name to set the value, and the function name to get the value.
///
/// The macro will also automatically generate boilerplate code for unit tests to work correctly.
macro_rules! global_value {
($name:ident, $type:ty, $init:expr, $set_fn:ident, $get_fn:ident) => {
use std::sync::RwLock;
#[cfg(not(test))]
use std::sync::RwLockReadGuard;

#[cfg(test)]
thread_local! {
static $name: RwLock<$type> = const { RwLock::new($init) };
}

#[cfg(not(test))]
static $name: RwLock<$type> = RwLock::new($init);

#[cfg(not(test))]
pub fn $set_fn(value: $type) {
*$name.write().expect("Failed to acquire lock on the mutex.") = value;
}

#[cfg(not(test))]
pub fn $get_fn() -> RwLockReadGuard<'static, $type> {
$name.read().expect("Failed to acquire lock on the mutex.")
}

#[cfg(test)]
pub fn $set_fn(new_value: $type) {
$name.with(|value| {
*value.write().expect("Failed to acquire lock on the mutex.") = new_value;
});
}

// This is not really a 1:1 replacement for the non-test RwLockReadGuard<'static, $type> as the RwLock may be tried to be
// dereferenced
#[cfg(test)]
pub fn $get_fn() -> $type {
$name.with(|value| {
value
.read()
.expect("Failed to acquire lock on the mutex.")
.clone()
})
}
};
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ pub mod config;
pub mod db;
pub mod enterprise;
mod error;
pub mod globals;
pub mod grpc;
pub mod handlers;
pub mod headers;
Expand Down
1 change: 0 additions & 1 deletion src/templates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,6 @@ pub fn gateway_reconnected_mail(
context.insert("gateway_ip", gateway_ip);
context.insert("network_name", network_name);
tera.add_raw_template("mail_gateway_reconnected", MAIL_GATEWAY_RECONNECTED)?;
println!("dupa");
Ok(tera.render("mail_gateway_reconnected", &context)?)
}

Expand Down
25 changes: 6 additions & 19 deletions src/updates.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use std::{
env,
sync::{RwLock, RwLockReadGuard},
};
use std::env;

use chrono::NaiveDate;
use semver::Version;

use crate::global_value;

const PRODUCT_NAME: &str = "Defguard";
const UPDATES_URL: &str = "https://update-service-dev.defguard.net/api/update/check";
const VERSION: &str = env!("CARGO_PKG_VERSION");

#[derive(Deserialize, Debug, Serialize)]
#[derive(Deserialize, Debug, Serialize, Clone)]
pub struct Update {
version: String,
release_date: NaiveDate,
Expand All @@ -20,19 +19,7 @@ pub struct Update {
notes: String,
}

static NEW_UPDATE: RwLock<Option<Update>> = RwLock::new(None);

fn set_update(update: Update) {
*NEW_UPDATE
.write()
.expect("Failed to acquire lock on the update.") = Some(update);
}

pub fn get_update() -> RwLockReadGuard<'static, Option<Update>> {
NEW_UPDATE
.read()
.expect("Failed to acquire lock on the update.")
}
global_value!(NEW_UPDATE, Option<Update>, None, set_update, get_update);

async fn fetch_update() -> Result<Update, anyhow::Error> {
let body = serde_json::json!({
Expand Down Expand Up @@ -63,7 +50,7 @@ pub(crate) async fn do_new_version_check() -> Result<(), anyhow::Error> {
update.version, update.release_date
);
}
set_update(update);
set_update(Some(update));
} else {
debug!("New version check done. You are using the latest version of Defguard.");
}
Expand Down

0 comments on commit 104b668

Please sign in to comment.