From b338cc8706f726dd432264e085bf898b38049d7b Mon Sep 17 00:00:00 2001 From: Jacek Chmielewski Date: Thu, 20 Nov 2025 17:12:12 +0100 Subject: [PATCH 01/17] Merge pull request #1706 from DefGuard/all-traffic-only Related issue: https://github.com/DefGuard/defguard/issues/880 Adds "force all traffic" option to enterprise settings. When selected, all clients are forced to route all traffic via the vpn. --- ...ba68521649ba86985d45049487ae50d7dfde8.json | 43 +++++++++ ...b715db1c1df67da573b0f132fea265e42b416.json | 32 ------- ...b3dbb0fc747ae0843baa001e47ea328e49c25.json | 27 ++++++ ...a5e7829eae217ad9f7f3b0b03aa02f8808dc2.json | 16 ---- .../db/models/enterprise_settings.rs | 33 +++++-- crates/defguard_core/src/grpc/mod.rs | 15 ++- .../integration/api/enterprise_settings.rs | 12 +-- ...51119122424_client_traffic_policy.down.sql | 13 +++ ...0251119122424_client_traffic_policy.up.sql | 19 ++++ proto | 2 +- web/src/i18n/en/index.ts | 23 ++++- web/src/i18n/i18n-types.ts | 92 +++++++++++++++---- web/src/i18n/pl/index.ts | 23 ++++- .../components/EnterpriseForm.tsx | 14 +-- .../TrafficPolicySelect.tsx | 77 ++++++++++++++++ .../components/TrafficPolicySelect/style.scss | 59 ++++++++++++ web/src/shared/types.ts | 8 +- 17 files changed, 399 insertions(+), 109 deletions(-) create mode 100644 .sqlx/query-160d23b882d0465fbc8c5453b7dba68521649ba86985d45049487ae50d7dfde8.json delete mode 100644 .sqlx/query-283e1c3d082f1388fc2b806bdcab715db1c1df67da573b0f132fea265e42b416.json create mode 100644 .sqlx/query-a644507ebcfb9ef04883ad8b07bb3dbb0fc747ae0843baa001e47ea328e49c25.json delete mode 100644 .sqlx/query-ccd62ea7526078c9db47812e7f6a5e7829eae217ad9f7f3b0b03aa02f8808dc2.json create mode 100644 migrations/20251119122424_client_traffic_policy.down.sql create mode 100644 migrations/20251119122424_client_traffic_policy.up.sql create mode 100644 web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/TrafficPolicySelect.tsx create mode 100644 web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/style.scss diff --git a/.sqlx/query-160d23b882d0465fbc8c5453b7dba68521649ba86985d45049487ae50d7dfde8.json b/.sqlx/query-160d23b882d0465fbc8c5453b7dba68521649ba86985d45049487ae50d7dfde8.json new file mode 100644 index 000000000..3b3177056 --- /dev/null +++ b/.sqlx/query-160d23b882d0465fbc8c5453b7dba68521649ba86985d45049487ae50d7dfde8.json @@ -0,0 +1,43 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT admin_device_management, client_traffic_policy \"client_traffic_policy: ClientTrafficPolicy\", only_client_activation FROM \"enterprisesettings\" WHERE id = 1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "admin_device_management", + "type_info": "Bool" + }, + { + "ordinal": 1, + "name": "client_traffic_policy: ClientTrafficPolicy", + "type_info": { + "Custom": { + "name": "client_traffic_policy", + "kind": { + "Enum": [ + "none", + "disable_all_traffic", + "force_all_traffic" + ] + } + } + } + }, + { + "ordinal": 2, + "name": "only_client_activation", + "type_info": "Bool" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "160d23b882d0465fbc8c5453b7dba68521649ba86985d45049487ae50d7dfde8" +} diff --git a/.sqlx/query-283e1c3d082f1388fc2b806bdcab715db1c1df67da573b0f132fea265e42b416.json b/.sqlx/query-283e1c3d082f1388fc2b806bdcab715db1c1df67da573b0f132fea265e42b416.json deleted file mode 100644 index c1696c5b0..000000000 --- a/.sqlx/query-283e1c3d082f1388fc2b806bdcab715db1c1df67da573b0f132fea265e42b416.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "SELECT admin_device_management, disable_all_traffic, only_client_activation FROM \"enterprisesettings\" WHERE id = 1", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "admin_device_management", - "type_info": "Bool" - }, - { - "ordinal": 1, - "name": "disable_all_traffic", - "type_info": "Bool" - }, - { - "ordinal": 2, - "name": "only_client_activation", - "type_info": "Bool" - } - ], - "parameters": { - "Left": [] - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "283e1c3d082f1388fc2b806bdcab715db1c1df67da573b0f132fea265e42b416" -} diff --git a/.sqlx/query-a644507ebcfb9ef04883ad8b07bb3dbb0fc747ae0843baa001e47ea328e49c25.json b/.sqlx/query-a644507ebcfb9ef04883ad8b07bb3dbb0fc747ae0843baa001e47ea328e49c25.json new file mode 100644 index 000000000..509af4943 --- /dev/null +++ b/.sqlx/query-a644507ebcfb9ef04883ad8b07bb3dbb0fc747ae0843baa001e47ea328e49c25.json @@ -0,0 +1,27 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE \"enterprisesettings\" SET admin_device_management = $1, client_traffic_policy = $2, only_client_activation = $3 WHERE id = 1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Bool", + { + "Custom": { + "name": "client_traffic_policy", + "kind": { + "Enum": [ + "none", + "disable_all_traffic", + "force_all_traffic" + ] + } + } + }, + "Bool" + ] + }, + "nullable": [] + }, + "hash": "a644507ebcfb9ef04883ad8b07bb3dbb0fc747ae0843baa001e47ea328e49c25" +} diff --git a/.sqlx/query-ccd62ea7526078c9db47812e7f6a5e7829eae217ad9f7f3b0b03aa02f8808dc2.json b/.sqlx/query-ccd62ea7526078c9db47812e7f6a5e7829eae217ad9f7f3b0b03aa02f8808dc2.json deleted file mode 100644 index 787a75d7d..000000000 --- a/.sqlx/query-ccd62ea7526078c9db47812e7f6a5e7829eae217ad9f7f3b0b03aa02f8808dc2.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "UPDATE \"enterprisesettings\" SET admin_device_management = $1, disable_all_traffic = $2, only_client_activation = $3 WHERE id = 1", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Bool", - "Bool", - "Bool" - ] - }, - "nullable": [] - }, - "hash": "ccd62ea7526078c9db47812e7f6a5e7829eae217ad9f7f3b0b03aa02f8808dc2" -} diff --git a/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs b/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs index f2e85ced8..d1c9be350 100644 --- a/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs +++ b/crates/defguard_core/src/enterprise/db/models/enterprise_settings.rs @@ -1,4 +1,4 @@ -use sqlx::{PgExecutor, query, query_as}; +use sqlx::{PgExecutor, Type, query, query_as}; use struct_patch::Patch; use crate::enterprise::is_enterprise_enabled; @@ -6,11 +6,11 @@ use crate::enterprise::is_enterprise_enabled; #[derive(Debug, Deserialize, Patch, Serialize)] #[patch(attribute(derive(Deserialize, Serialize)))] pub struct EnterpriseSettings { - // If true, only admins can manage devices + /// If true, only admins can manage devices pub admin_device_management: bool, - // If true, the option to route all traffic through the vpn is disabled in the client - pub disable_all_traffic: bool, - // If true, manual WireGuard setup is disabled + /// Describes allowed routing options for clients connecting to the instance. + pub client_traffic_policy: ClientTrafficPolicy, + /// If true, manual WireGuard setup is disabled pub only_client_activation: bool, } @@ -20,8 +20,8 @@ impl Default for EnterpriseSettings { fn default() -> Self { Self { admin_device_management: false, - disable_all_traffic: false, only_client_activation: false, + client_traffic_policy: ClientTrafficPolicy::default(), } } } @@ -39,7 +39,8 @@ impl EnterpriseSettings { let settings = query_as!( Self, "SELECT admin_device_management, \ - disable_all_traffic, only_client_activation \ + client_traffic_policy \"client_traffic_policy: ClientTrafficPolicy\", \ + only_client_activation \ FROM \"enterprisesettings\" WHERE id = 1", ) .fetch_optional(executor) @@ -57,11 +58,11 @@ impl EnterpriseSettings { query!( "UPDATE \"enterprisesettings\" SET \ admin_device_management = $1, \ - disable_all_traffic = $2, \ + client_traffic_policy = $2, \ only_client_activation = $3 \ WHERE id = 1", self.admin_device_management, - self.disable_all_traffic, + self.client_traffic_policy as ClientTrafficPolicy, self.only_client_activation, ) .execute(executor) @@ -70,3 +71,17 @@ impl EnterpriseSettings { Ok(()) } } + +/// Describes allowed traffic options for clients connecting to the instance. +#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Type, Debug, Default, Copy)] +#[sqlx(type_name = "client_traffic_policy", rename_all = "snake_case")] +#[serde(rename_all = "snake_case")] +pub enum ClientTrafficPolicy { + /// No restrictions + #[default] + None, + /// Clients are not allowed to route all traffic through the VPN. + DisableAllTraffic, + /// Clients are forced to route all traffic through the VPN. + ForceAllTraffic, +} diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 96dbd24e6..a4c4ba3dc 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -50,7 +50,10 @@ use crate::{ models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, }, enterprise::{ - db::models::{enterprise_settings::EnterpriseSettings, openid_provider::OpenIdProvider}, + db::models::{ + enterprise_settings::{ClientTrafficPolicy, EnterpriseSettings}, + openid_provider::OpenIdProvider, + }, directory_sync::sync_user_groups_if_configured, grpc::polling::PollingServer, handlers::openid_login::{ @@ -806,7 +809,7 @@ pub struct InstanceInfo { url: Url, proxy_url: Url, username: String, - disable_all_traffic: bool, + client_traffic_policy: ClientTrafficPolicy, enterprise_enabled: bool, openid_display_name: Option, } @@ -829,7 +832,7 @@ impl InstanceInfo { url: config.url.clone(), proxy_url: config.enrollment_url.clone(), username: username.into(), - disable_all_traffic: enterprise_settings.disable_all_traffic, + client_traffic_policy: enterprise_settings.client_traffic_policy, enterprise_enabled: is_enterprise_enabled(), openid_display_name, } @@ -844,7 +847,11 @@ impl From for defguard_proto::proxy::InstanceInfo { url: instance.url.to_string(), proxy_url: instance.proxy_url.to_string(), username: instance.username, - disable_all_traffic: instance.disable_all_traffic, + // Ensure backwards compatibility. + #[allow(deprecated)] + disable_all_traffic: instance.client_traffic_policy + == ClientTrafficPolicy::DisableAllTraffic, + client_traffic_policy: Some(instance.client_traffic_policy as i32), enterprise_enabled: instance.enterprise_enabled, openid_display_name: instance.openid_display_name, } diff --git a/crates/defguard_core/tests/integration/api/enterprise_settings.rs b/crates/defguard_core/tests/integration/api/enterprise_settings.rs index f24582a65..c065dde6c 100644 --- a/crates/defguard_core/tests/integration/api/enterprise_settings.rs +++ b/crates/defguard_core/tests/integration/api/enterprise_settings.rs @@ -1,6 +1,6 @@ use defguard_core::{ enterprise::{ - db::models::enterprise_settings::EnterpriseSettings, + db::models::enterprise_settings::{ClientTrafficPolicy, EnterpriseSettings}, license::{get_cached_license, set_cached_license}, }, handlers::Auth, @@ -33,7 +33,7 @@ async fn test_only_enterprise_can_modify_enterpise_settings( // try to patch enterprise settings let settings = EnterpriseSettings { admin_device_management: false, - disable_all_traffic: false, + client_traffic_policy: ClientTrafficPolicy::None, only_client_activation: false, }; @@ -81,7 +81,7 @@ async fn test_admin_devices_management_is_enforced(_: PgPoolOptions, options: Pg // setup admin devices management let settings = EnterpriseSettings { admin_device_management: true, - disable_all_traffic: false, + client_traffic_policy: ClientTrafficPolicy::None, only_client_activation: false, }; let response = client @@ -177,7 +177,7 @@ async fn test_regular_user_device_management(_: PgPoolOptions, options: PgConnec // setup admin devices management let settings = EnterpriseSettings { admin_device_management: false, - disable_all_traffic: false, + client_traffic_policy: ClientTrafficPolicy::None, only_client_activation: false, }; let response = client @@ -265,7 +265,7 @@ async fn dg25_12_test_enforce_client_activation_only(_: PgPoolOptions, options: // disable manual device management let settings = EnterpriseSettings { admin_device_management: false, - disable_all_traffic: false, + client_traffic_policy: ClientTrafficPolicy::None, only_client_activation: true, }; let response = client @@ -346,7 +346,7 @@ async fn dg25_13_test_disable_device_config(_: PgPoolOptions, options: PgConnect // disable manual device management let settings = EnterpriseSettings { admin_device_management: false, - disable_all_traffic: false, + client_traffic_policy: ClientTrafficPolicy::None, only_client_activation: true, }; let response = client diff --git a/migrations/20251119122424_client_traffic_policy.down.sql b/migrations/20251119122424_client_traffic_policy.down.sql new file mode 100644 index 000000000..db730678d --- /dev/null +++ b/migrations/20251119122424_client_traffic_policy.down.sql @@ -0,0 +1,13 @@ +-- restore boolean `mfa_enabled` column +ALTER TABLE enterprisesettings ADD COLUMN "disable_all_traffic" BOOLEAN NOT NULL DEFAULT false; + +-- populate based on client traffic policy +UPDATE enterprisesettings +SET disable_all_traffic = CASE + WHEN client_traffic_policy = 'disable_all_traffic'::client_traffic_policy THEN true + ELSE false +END; + +-- drop new column and type +ALTER TABLE enterprisesettings DROP COLUMN "client_traffic_policy"; +DROP TYPE client_traffic_policy; diff --git a/migrations/20251119122424_client_traffic_policy.up.sql b/migrations/20251119122424_client_traffic_policy.up.sql new file mode 100644 index 000000000..a5bfb8d8a --- /dev/null +++ b/migrations/20251119122424_client_traffic_policy.up.sql @@ -0,0 +1,19 @@ +-- add enum representing client traffic policy +CREATE TYPE client_traffic_policy AS ENUM ( + 'none', + 'disable_all_traffic', + 'force_all_traffic' +); + +-- add column to `enterprisesettings` table +ALTER TABLE enterprisesettings ADD COLUMN "client_traffic_policy" client_traffic_policy NOT NULL DEFAULT 'none'; + +-- populate new column based on value in `disable_all_traffic` column +UPDATE enterprisesettings +SET client_traffic_policy = CASE + WHEN disable_all_traffic = true THEN 'disable_all_traffic'::client_traffic_policy + ELSE 'none'::client_traffic_policy +END; + +-- drop the `disable_all_traffic` column since it's no longer needed +ALTER TABLE enterprisesettings DROP COLUMN "disable_all_traffic"; diff --git a/proto b/proto index 96249ebde..74d60d917 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 96249ebde0556f4ae8c47eebc6015efb04ed0104 +Subproject commit 74d60d9171048ba0ccaf8a21b05950fb7a673f09 diff --git a/web/src/i18n/en/index.ts b/web/src/i18n/en/index.ts index f4b94015e..ce572bb6b 100644 --- a/web/src/i18n/en/index.ts +++ b/web/src/i18n/en/index.ts @@ -1708,16 +1708,29 @@ Licensing information: [https://docs.defguard.net/enterprise/license](https://do helper: "When this option is enabled, only users in the Admin group can manage devices in user profile (it's disabled for all other users)", }, - disableAllTraffic: { - label: 'Disable the option to route all traffic through VPN', - helper: - 'When this option is enabled, users will not be able to route all traffic through the VPN using the defguard client.', - }, manualConfig: { label: "Disable users' ability to manually configure WireGuard client", helper: "When this option is enabled, users won't be able to view or download configuration for the manual WireGuard client setup. Only the Defguard desktop client configuration will be available.", }, + clientTrafficPolicy: { + header: 'Client traffic policy', + none: { + label: 'None', + helper: + 'When this option is enabled, users will be able to select all routing options.', + }, + disableAllTraffic: { + label: 'Disable the option to route all traffic through VPN', + helper: + 'When this option is enabled, users will not be able to route all traffic through the VPN.', + }, + forceAllTraffic: { + label: 'Force the clients to route all traffic through VPN', + helper: + 'When this option is enabled, the users will always route all traffic through the VPN.', + }, + }, }, }, gatewayNotifications: { diff --git a/web/src/i18n/i18n-types.ts b/web/src/i18n/i18n-types.ts index 2313f93c1..517c50734 100644 --- a/web/src/i18n/i18n-types.ts +++ b/web/src/i18n/i18n-types.ts @@ -4082,16 +4082,6 @@ type RootTranslation = { */ helper: string } - disableAllTraffic: { - /** - * D​i​s​a​b​l​e​ ​t​h​e​ ​o​p​t​i​o​n​ ​t​o​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​V​P​N - */ - label: string - /** - * W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​u​s​e​r​s​ ​w​i​l​l​ ​n​o​t​ ​b​e​ ​a​b​l​e​ ​t​o​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​t​h​e​ ​V​P​N​ ​u​s​i​n​g​ ​t​h​e​ ​d​e​f​g​u​a​r​d​ ​c​l​i​e​n​t​. - */ - helper: string - } manualConfig: { /** * D​i​s​a​b​l​e​ ​u​s​e​r​s​'​ ​a​b​i​l​i​t​y​ ​t​o​ ​m​a​n​u​a​l​l​y​ ​c​o​n​f​i​g​u​r​e​ ​W​i​r​e​G​u​a​r​d​ ​c​l​i​e​n​t @@ -4102,6 +4092,42 @@ type RootTranslation = { */ helper: string } + clientTrafficPolicy: { + /** + * C​l​i​e​n​t​ ​t​r​a​f​f​i​c​ ​p​o​l​i​c​y + */ + header: string + none: { + /** + * N​o​n​e + */ + label: string + /** + * W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​u​s​e​r​s​ ​w​i​l​l​ ​b​e​ ​a​b​l​e​ ​t​o​ ​s​e​l​e​c​t​ ​a​l​l​ ​r​o​u​t​i​n​g​ ​o​p​t​i​o​n​s​. + */ + helper: string + } + disableAllTraffic: { + /** + * D​i​s​a​b​l​e​ ​t​h​e​ ​o​p​t​i​o​n​ ​t​o​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​V​P​N + */ + label: string + /** + * W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​u​s​e​r​s​ ​w​i​l​l​ ​n​o​t​ ​b​e​ ​a​b​l​e​ ​t​o​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​t​h​e​ ​V​P​N​. + */ + helper: string + } + forceAllTraffic: { + /** + * F​o​r​c​e​ ​t​h​e​ ​c​l​i​e​n​t​s​ ​t​o​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​V​P​N + */ + label: string + /** + * W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​t​h​e​ ​u​s​e​r​s​ ​w​i​l​l​ ​a​l​w​a​y​s​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​t​h​e​ ​V​P​N​. + */ + helper: string + } + } } } gatewayNotifications: { @@ -10813,16 +10839,6 @@ export type TranslationFunctions = { */ helper: () => LocalizedString } - disableAllTraffic: { - /** - * Disable the option to route all traffic through VPN - */ - label: () => LocalizedString - /** - * When this option is enabled, users will not be able to route all traffic through the VPN using the defguard client. - */ - helper: () => LocalizedString - } manualConfig: { /** * Disable users' ability to manually configure WireGuard client @@ -10833,6 +10849,42 @@ export type TranslationFunctions = { */ helper: () => LocalizedString } + clientTrafficPolicy: { + /** + * Client traffic policy + */ + header: () => LocalizedString + none: { + /** + * None + */ + label: () => LocalizedString + /** + * When this option is enabled, users will be able to select all routing options. + */ + helper: () => LocalizedString + } + disableAllTraffic: { + /** + * Disable the option to route all traffic through VPN + */ + label: () => LocalizedString + /** + * When this option is enabled, users will not be able to route all traffic through the VPN. + */ + helper: () => LocalizedString + } + forceAllTraffic: { + /** + * Force the clients to route all traffic through VPN + */ + label: () => LocalizedString + /** + * When this option is enabled, the users will always route all traffic through the VPN. + */ + helper: () => LocalizedString + } + } } } gatewayNotifications: { diff --git a/web/src/i18n/pl/index.ts b/web/src/i18n/pl/index.ts index 652590151..7356101f6 100644 --- a/web/src/i18n/pl/index.ts +++ b/web/src/i18n/pl/index.ts @@ -1494,16 +1494,29 @@ Uwaga, podane tutaj konfiguracje nie posiadają klucza prywatnego. Musisz uzupe helper: 'Kiedy ta opcja jest włączona, tylko użytkownicy w grupie "Admin" mogą zarządzać urządzeniami w profilu użytkownika', }, - disableAllTraffic: { - label: 'Zablokuj możliwość przekierowania całego ruchu przez VPN', - helper: - 'Kiedy ta opcja jest włączona, użytkownicy nie będą mogli przekierować całego ruchu przez VPN za pomocą klienta Defguard.', - }, manualConfig: { label: 'Wyłącz manualną konfigurację WireGuard', helper: 'Kiedy ta opcja jest włączona, użytkownicy nie będą mogli pobrać ani wyświetlić danych do manualnej konfiguracji WireGuard. Możliwe będzie wyłącznie skonfigurowanie klienta Defguard.', }, + clientTrafficPolicy: { + header: 'Polityka przekierowania ruchu klientów', + none: { + label: 'Brak', + helper: + 'Kiedy ta opcja jest włączona, użytkownicy mogą wybierać dowolny typ przekierowania ruchu.', + }, + disableAllTraffic: { + label: 'Zablokuj możliwość przekierowania całego ruchu przez VPN', + helper: + 'Kiedy ta opcja jest włączona, użytkownicy nie będą mogli przekierować całego ruchu przez VPN.', + }, + forceAllTraffic: { + label: 'Wymuś przekierowanie całego ruchu przez VPN', + helper: + 'Kiedy ta opcja jest włączona, użytkownicy będą zawsze przekierowywać cały ruch przez VPN.', + }, + } }, }, gatewayNotifications: { diff --git a/web/src/pages/settings/components/EnterpriseSettings/components/EnterpriseForm.tsx b/web/src/pages/settings/components/EnterpriseSettings/components/EnterpriseForm.tsx index d7df932de..c8bcdc885 100644 --- a/web/src/pages/settings/components/EnterpriseSettings/components/EnterpriseForm.tsx +++ b/web/src/pages/settings/components/EnterpriseSettings/components/EnterpriseForm.tsx @@ -12,6 +12,7 @@ import useApi from '../../../../../shared/hooks/useApi'; import { useToaster } from '../../../../../shared/hooks/useToaster'; import { MutationKeys } from '../../../../../shared/mutations'; import { QueryKeys } from '../../../../../shared/queries'; +import { ClientTrafficPolicySelect } from './TrafficPolicySelect/TrafficPolicySelect'; export const EnterpriseForm = () => { const { LL } = useI18nContext(); @@ -77,17 +78,10 @@ export const EnterpriseForm = () => {
- - mutate({ disable_all_traffic: !settings.disable_all_traffic }) - } + mutate({ client_traffic_policy: value })} + fieldValue={settings.client_traffic_policy} /> - - {parse(LL.settingsPage.enterprise.fields.disableAllTraffic.helper())} -
diff --git a/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/TrafficPolicySelect.tsx b/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/TrafficPolicySelect.tsx new file mode 100644 index 000000000..43fa5f3f9 --- /dev/null +++ b/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/TrafficPolicySelect.tsx @@ -0,0 +1,77 @@ +import './style.scss'; +import clsx from 'clsx'; +import parse from 'html-react-parser'; +import { useMemo } from 'react'; +import { useI18nContext } from '../../../../../../i18n/i18n-react'; +import { Helper } from '../../../../../../shared/defguard-ui/components/Layout/Helper/Helper'; +import { RadioButton } from '../../../../../../shared/defguard-ui/components/Layout/RadioButton/Radiobutton'; +import type { SelectOption } from '../../../../../../shared/defguard-ui/components/Layout/Select/types'; +import { ClientTrafficPolicy } from '../../../../../../shared/types'; + +type Props = { + onChange: (event: ClientTrafficPolicy) => void; + fieldValue: ClientTrafficPolicy; +}; + +export const ClientTrafficPolicySelect = ({ onChange, fieldValue }: Props) => { + const { LL } = useI18nContext(); + const options = useMemo( + (): SelectOption[] => [ + { + key: ClientTrafficPolicy.NONE, + value: ClientTrafficPolicy.NONE, + label: LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.label(), + meta: LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.helper(), + }, + { + key: ClientTrafficPolicy.DISABLE_ALL_TRAFFIC, + value: ClientTrafficPolicy.DISABLE_ALL_TRAFFIC, + label: + LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.label(), + meta: LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.helper(), + }, + { + key: ClientTrafficPolicy.FORCE_ALL_TRAFFIC, + value: ClientTrafficPolicy.FORCE_ALL_TRAFFIC, + label: + LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.label(), + meta: LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.helper(), + }, + ], + [ + LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.label, + LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.helper, + LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.label, + LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.helper, + LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.label, + LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.helper, + ], + ); + + return ( +
+ + {options.map(({ key, value, label, meta, disabled = false }) => { + const active = fieldValue === value; + return ( +
{ + if (!disabled) { + onChange(value); + } + }} + > +

{label}

+ + {parse(meta)} +
+ ); + })} +
+ ); +}; diff --git a/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/style.scss b/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/style.scss new file mode 100644 index 000000000..692d1123b --- /dev/null +++ b/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/style.scss @@ -0,0 +1,59 @@ +.client-traffic-policy-select { + display: flex; + flex-flow: column; + row-gap: var(--spacing-s); + margin-bottom: 25px; + + .client-traffic-policy { + display: flex; + align-items: center; + justify-content: space-between; + column-gap: var(--spacing-xs); + min-height: 30px; + border: 1px solid var(--border-primary); + padding: var(--spacing-xs) var(--spacing-s); + border-radius: 10px; + cursor: pointer; + user-select: none; + transition-property: border-color, opacity; + @include animate-standard; + + &:not(.active) { + &:hover { + border-color: var(--border-separator); + } + } + + &.active { + border-color: var(--surface-main-primary); + } + + &.active, + &:hover { + .label { + color: var(--text-body-primary); + } + } + + &.disabled { + opacity: 0.6; + cursor: not-allowed; + background-color: var(--surface-secondary); + + .label { + color: var(--text-body-disabled); + } + + &:hover { + border-color: var(--border-primary); + } + } + + .label { + color: var(--text-body-secondary); + transition-property: color; + @include typography(app-modal-1); + @include animate-standard; + } + } +} diff --git a/web/src/shared/types.ts b/web/src/shared/types.ts index 3c181cb11..33258f66b 100644 --- a/web/src/shared/types.ts +++ b/web/src/shared/types.ts @@ -1121,9 +1121,15 @@ export type SettingsGatewayNotifications = { gateway_disconnect_notifications_reconnect_notification_enabled: boolean; }; +export enum ClientTrafficPolicy { + NONE = 'none', + DISABLE_ALL_TRAFFIC = 'disable_all_traffic', + FORCE_ALL_TRAFFIC = 'force_all_traffic', +} + export type SettingsEnterprise = { admin_device_management: boolean; - disable_all_traffic: boolean; + client_traffic_policy: ClientTrafficPolicy; only_client_activation: boolean; }; From 0659ae1b9fcd5a9d3e03a5c9c7c7f572e86955af Mon Sep 17 00:00:00 2001 From: jakub-tldr <78603704+jakub-tldr@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:18:35 +0100 Subject: [PATCH 02/17] Filter MFA locations on network devices modal, block creating devices without name (#1719) * filter mfa locations, validate ip/domain in wizard * Reject device without name --- crates/defguard_core/src/grpc/enrollment.rs | 5 +++++ .../steps/MethodStep/MethodStep.tsx | 12 +++++++----- .../WizardNetworkConfiguration.tsx | 5 ++++- web/src/shared/validators.ts | 18 ++++++++++++------ 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/crates/defguard_core/src/grpc/enrollment.rs b/crates/defguard_core/src/grpc/enrollment.rs index c7a43069b..13b0733ea 100644 --- a/crates/defguard_core/src/grpc/enrollment.rs +++ b/crates/defguard_core/src/grpc/enrollment.rs @@ -695,6 +695,11 @@ impl EnrollmentServer { None, true, ); + if device.name.is_empty() { + return Err(Status::invalid_argument( + "Cannot add a new device with no name. You may be trying to add a new user device as a network device. Defguard CLI supports only network devices.", + )); + } let device = device.save(&mut *transaction).await.map_err(|err| { error!( "Failed to save device {}, pubkey {} for user {}({:?}): {err}", diff --git a/web/src/pages/devices/modals/AddStandaloneDeviceModal/steps/MethodStep/MethodStep.tsx b/web/src/pages/devices/modals/AddStandaloneDeviceModal/steps/MethodStep/MethodStep.tsx index a054a9150..ef6b9f085 100644 --- a/web/src/pages/devices/modals/AddStandaloneDeviceModal/steps/MethodStep/MethodStep.tsx +++ b/web/src/pages/devices/modals/AddStandaloneDeviceModal/steps/MethodStep/MethodStep.tsx @@ -88,11 +88,13 @@ export const MethodStep = () => { // biome-ignore lint/correctness/useExhaustiveDependencies: migration, checkMeLater useEffect(() => { if (networks) { - const options: SelectOption[] = networks.map((n) => ({ - key: n.id, - value: n.id, - label: n.name, - })); + const options: SelectOption[] = networks + .filter((n) => n.location_mfa_mode === 'disabled') + .map((n) => ({ + key: n.id, + value: n.id, + label: n.name, + })); setState({ networks, networkOptions: options, diff --git a/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx b/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx index 8d416e221..6ecdd2034 100644 --- a/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx +++ b/web/src/pages/wizard/components/WizardNetworkConfiguration/WizardNetworkConfiguration.tsx @@ -118,7 +118,10 @@ export const WizardNetworkConfiguration = () => { .string() .trim() .min(1, LL.form.error.required()) - .refine((val) => validateIpOrDomain(val), LL.form.error.endpoint()), + .refine( + (val) => validateIpOrDomain(val, false, true), + LL.form.error.endpoint(), + ), port: z .number({ invalid_type_error: LL.form.error.invalid(), diff --git a/web/src/shared/validators.ts b/web/src/shared/validators.ts index 77df62d88..8d9966400 100644 --- a/web/src/shared/validators.ts +++ b/web/src/shared/validators.ts @@ -1,6 +1,5 @@ import ipaddr from 'ipaddr.js'; import { z } from 'zod'; - import { patternValidDomain, patternValidWireguardKey } from './patterns'; export const validateWireguardPublicKey = (props: { @@ -24,11 +23,13 @@ export const validateIpOrDomain = ( allowMask = false, allowIPv6 = false, ): boolean => { - return ( - (allowIPv6 && validateIPv6(val, allowMask)) || - validateIPv4(val, allowMask) || - patternValidDomain.test(val) - ); + const hasLetter = /\p{L}/u.test(val); + const hasColon = /:/.test(val); + if (!hasLetter || hasColon) { + return (allowIPv6 && validateIPv6(val, allowMask)) || validateIPv4(val, allowMask); + } else { + return patternValidDomain.test(val); + } }; // Returns false when invalid @@ -41,6 +42,7 @@ export const validateIpList = ( .replace(' ', '') .split(splitWith) .every((el) => { + if (!el.includes('/') && allowMasks) return false; return validateIPv4(el, allowMasks) || validateIPv6(el, allowMasks); }); }; @@ -76,6 +78,10 @@ export const validateIPv4 = (ip: string, allowMask = false): boolean => { return ipaddr.IPv4.isValidCIDR(ip); } } + const ipv4Pattern = /^(\d{1,3}\.){3}\d{1,3}$/; + if (!ipv4Pattern.test(ip)) { + return false; + } return ipaddr.IPv4.isValid(ip); }; From da226a8ca11ab1e8bbb2b6f0fcc3f69cc254b39e Mon Sep 17 00:00:00 2001 From: Jacek Chmielewski Date: Mon, 24 Nov 2025 09:56:41 +0100 Subject: [PATCH 03/17] Fix traffic policy settings styling (#1720) * fix client traffic policy helpers styling * remove unused useMemo deps * tweak the header --- web/src/i18n/en/index.ts | 6 +- web/src/i18n/i18n-types.ts | 12 +-- web/src/i18n/pl/index.ts | 6 +- .../TrafficPolicySelect.tsx | 75 +++++++++++-------- .../components/TrafficPolicySelect/style.scss | 13 ++++ 5 files changed, 69 insertions(+), 43 deletions(-) diff --git a/web/src/i18n/en/index.ts b/web/src/i18n/en/index.ts index ce572bb6b..fa88c26ca 100644 --- a/web/src/i18n/en/index.ts +++ b/web/src/i18n/en/index.ts @@ -1718,17 +1718,17 @@ Licensing information: [https://docs.defguard.net/enterprise/license](https://do none: { label: 'None', helper: - 'When this option is enabled, users will be able to select all routing options.', + 'None - When this option is enabled, users will be able to select all routing options.', }, disableAllTraffic: { label: 'Disable the option to route all traffic through VPN', helper: - 'When this option is enabled, users will not be able to route all traffic through the VPN.', + 'Disable all traffic - When this option is enabled, users will not be able to route all traffic through the VPN.', }, forceAllTraffic: { label: 'Force the clients to route all traffic through VPN', helper: - 'When this option is enabled, the users will always route all traffic through the VPN.', + 'Force all traffic - When this option is enabled, the users will always route all traffic through the VPN.', }, }, }, diff --git a/web/src/i18n/i18n-types.ts b/web/src/i18n/i18n-types.ts index 517c50734..a2ce0a212 100644 --- a/web/src/i18n/i18n-types.ts +++ b/web/src/i18n/i18n-types.ts @@ -4103,7 +4103,7 @@ type RootTranslation = { */ label: string /** - * W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​u​s​e​r​s​ ​w​i​l​l​ ​b​e​ ​a​b​l​e​ ​t​o​ ​s​e​l​e​c​t​ ​a​l​l​ ​r​o​u​t​i​n​g​ ​o​p​t​i​o​n​s​. + * N​o​n​e​ ​-​ ​W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​u​s​e​r​s​ ​w​i​l​l​ ​b​e​ ​a​b​l​e​ ​t​o​ ​s​e​l​e​c​t​ ​a​l​l​ ​r​o​u​t​i​n​g​ ​o​p​t​i​o​n​s​. */ helper: string } @@ -4113,7 +4113,7 @@ type RootTranslation = { */ label: string /** - * W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​u​s​e​r​s​ ​w​i​l​l​ ​n​o​t​ ​b​e​ ​a​b​l​e​ ​t​o​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​t​h​e​ ​V​P​N​. + * D​i​s​a​b​l​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​-​ ​W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​u​s​e​r​s​ ​w​i​l​l​ ​n​o​t​ ​b​e​ ​a​b​l​e​ ​t​o​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​t​h​e​ ​V​P​N​. */ helper: string } @@ -4123,7 +4123,7 @@ type RootTranslation = { */ label: string /** - * W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​t​h​e​ ​u​s​e​r​s​ ​w​i​l​l​ ​a​l​w​a​y​s​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​t​h​e​ ​V​P​N​. + * F​o​r​c​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​-​ ​W​h​e​n​ ​t​h​i​s​ ​o​p​t​i​o​n​ ​i​s​ ​e​n​a​b​l​e​d​,​ ​t​h​e​ ​u​s​e​r​s​ ​w​i​l​l​ ​a​l​w​a​y​s​ ​r​o​u​t​e​ ​a​l​l​ ​t​r​a​f​f​i​c​ ​t​h​r​o​u​g​h​ ​t​h​e​ ​V​P​N​. */ helper: string } @@ -10860,7 +10860,7 @@ export type TranslationFunctions = { */ label: () => LocalizedString /** - * When this option is enabled, users will be able to select all routing options. + * None - When this option is enabled, users will be able to select all routing options. */ helper: () => LocalizedString } @@ -10870,7 +10870,7 @@ export type TranslationFunctions = { */ label: () => LocalizedString /** - * When this option is enabled, users will not be able to route all traffic through the VPN. + * Disable all traffic - When this option is enabled, users will not be able to route all traffic through the VPN. */ helper: () => LocalizedString } @@ -10880,7 +10880,7 @@ export type TranslationFunctions = { */ label: () => LocalizedString /** - * When this option is enabled, the users will always route all traffic through the VPN. + * Force all traffic - When this option is enabled, the users will always route all traffic through the VPN. */ helper: () => LocalizedString } diff --git a/web/src/i18n/pl/index.ts b/web/src/i18n/pl/index.ts index 7356101f6..325591623 100644 --- a/web/src/i18n/pl/index.ts +++ b/web/src/i18n/pl/index.ts @@ -1504,17 +1504,17 @@ Uwaga, podane tutaj konfiguracje nie posiadają klucza prywatnego. Musisz uzupe none: { label: 'Brak', helper: - 'Kiedy ta opcja jest włączona, użytkownicy mogą wybierać dowolny typ przekierowania ruchu.', + 'Brak - Kiedy ta opcja jest włączona, użytkownicy mogą wybrać dowolny typ przekierowania ruchu.', }, disableAllTraffic: { label: 'Zablokuj możliwość przekierowania całego ruchu przez VPN', helper: - 'Kiedy ta opcja jest włączona, użytkownicy nie będą mogli przekierować całego ruchu przez VPN.', + 'Zablokuj przekierowanie całego ruchu - Kiedy ta opcja jest włączona, użytkownicy nie będą mogli przekierować całego ruchu przez VPN.', }, forceAllTraffic: { label: 'Wymuś przekierowanie całego ruchu przez VPN', helper: - 'Kiedy ta opcja jest włączona, użytkownicy będą zawsze przekierowywać cały ruch przez VPN.', + 'Wymuś przekierowanie całego ruchu - Kiedy ta opcja jest włączona, użytkownicy będą zawsze przekierowywać cały ruch przez VPN.', }, } }, diff --git a/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/TrafficPolicySelect.tsx b/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/TrafficPolicySelect.tsx index 43fa5f3f9..2faf81784 100644 --- a/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/TrafficPolicySelect.tsx +++ b/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/TrafficPolicySelect.tsx @@ -1,9 +1,8 @@ import './style.scss'; import clsx from 'clsx'; -import parse from 'html-react-parser'; import { useMemo } from 'react'; import { useI18nContext } from '../../../../../../i18n/i18n-react'; -import { Helper } from '../../../../../../shared/defguard-ui/components/Layout/Helper/Helper'; +import { MessageBox } from '../../../../../../shared/defguard-ui/components/Layout/MessageBox/MessageBox'; import { RadioButton } from '../../../../../../shared/defguard-ui/components/Layout/RadioButton/Radiobutton'; import type { SelectOption } from '../../../../../../shared/defguard-ui/components/Layout/Select/types'; import { ClientTrafficPolicy } from '../../../../../../shared/types'; @@ -21,57 +20,71 @@ export const ClientTrafficPolicySelect = ({ onChange, fieldValue }: Props) => { key: ClientTrafficPolicy.NONE, value: ClientTrafficPolicy.NONE, label: LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.label(), - meta: LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.helper(), }, { key: ClientTrafficPolicy.DISABLE_ALL_TRAFFIC, value: ClientTrafficPolicy.DISABLE_ALL_TRAFFIC, label: LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.label(), - meta: LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.helper(), }, { key: ClientTrafficPolicy.FORCE_ALL_TRAFFIC, value: ClientTrafficPolicy.FORCE_ALL_TRAFFIC, label: LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.label(), - meta: LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.helper(), }, ], [ LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.label, - LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.helper, LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.label, - LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.helper, LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.label, - LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.helper, ], ); return ( -
- - {options.map(({ key, value, label, meta, disabled = false }) => { - const active = fieldValue === value; - return ( -
{ - if (!disabled) { - onChange(value); - } - }} - > -

{label}

- - {parse(meta)} -
- ); - })} +
+
+

{LL.settingsPage.enterprise.fields.clientTrafficPolicy.header()}

+
+
+ +
    +
  • +

    {LL.settingsPage.enterprise.fields.clientTrafficPolicy.none.helper()}

    +
  • +
  • +

    + {LL.settingsPage.enterprise.fields.clientTrafficPolicy.disableAllTraffic.helper()} +

    +
  • +
  • +

    + {LL.settingsPage.enterprise.fields.clientTrafficPolicy.forceAllTraffic.helper()} +

    +
  • +
+
+ {options.map(({ key, value, label, disabled = false }) => { + const active = fieldValue === value; + return ( +
{ + if (!disabled) { + onChange(value); + } + }} + > +

{label}

+ +
+ ); + })} +
); }; diff --git a/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/style.scss b/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/style.scss index 692d1123b..ad3f91753 100644 --- a/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/style.scss +++ b/web/src/pages/settings/components/EnterpriseSettings/components/TrafficPolicySelect/style.scss @@ -56,4 +56,17 @@ @include animate-standard; } } + + #client-traffic-policy-message-box { + ul { + list-style-position: inside; + margin-top: 8px; + + li { + p { + display: inline; + } + } + } + } } From 5aa68b2ea2df81e99af6a9bd237c55ab0036ba0a Mon Sep 17 00:00:00 2001 From: jakub-tldr <78603704+jakub-tldr@users.noreply.github.com> Date: Tue, 25 Nov 2025 12:22:28 +0100 Subject: [PATCH 04/17] Fix validator for ipv4 with port (#1723) --- web/src/shared/validators.ts | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/web/src/shared/validators.ts b/web/src/shared/validators.ts index 8d9966400..0480c722c 100644 --- a/web/src/shared/validators.ts +++ b/web/src/shared/validators.ts @@ -79,9 +79,19 @@ export const validateIPv4 = (ip: string, allowMask = false): boolean => { } } const ipv4Pattern = /^(\d{1,3}\.){3}\d{1,3}$/; - if (!ipv4Pattern.test(ip)) { + const ipv4WithPortPattern = /^(\d{1,3}\.){3}\d{1,3}(:\d{1,5})?$/; + if (!ipv4Pattern.test(ip) && !ipv4WithPortPattern.test(ip)) { return false; } + + if (ipv4WithPortPattern.test(ip)) { + const [address, port] = ip.split(':'); + ip = address; + if (!validatePort(port)) { + return false; + } + } + return ipaddr.IPv4.isValid(ip); }; From 3b3dc271ba3096651fca4cba1e735ac6b79bd3e5 Mon Sep 17 00:00:00 2001 From: Jacek Chmielewski Date: Wed, 26 Nov 2025 10:32:34 +0100 Subject: [PATCH 05/17] fix ipv4 validator (#1726) --- flake.lock | 12 ++++++------ flake.nix | 1 + web/src/shared/validators.ts | 2 +- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/flake.lock b/flake.lock index 3de9f99f5..8f18766cd 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1763283776, - "narHash": "sha256-Y7TDFPK4GlqrKrivOcsHG8xSGqQx3A6c+i7novT85Uk=", + "lastModified": 1763966396, + "narHash": "sha256-6eeL1YPcY1MV3DDStIDIdy/zZCDKgHdkCmsrLJFiZf0=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "50a96edd8d0db6cc8db57dab6bb6d6ee1f3dc49a", + "rev": "5ae3b07d8d6527c42f17c876e404993199144b6a", "type": "github" }, "original": { @@ -48,11 +48,11 @@ ] }, "locked": { - "lastModified": 1763347184, - "narHash": "sha256-6QH8hpCYJxifvyHEYg+Da0BotUn03BwLIvYo3JAxuqQ=", + "lastModified": 1764124769, + "narHash": "sha256-vcoOEy3i8AGJi3Y2C48hrf6CuL2h8W1gLe1gNt72Kxg=", "owner": "oxalica", "repo": "rust-overlay", - "rev": "08895cce80433978d5bfd668efa41c5e24578cbd", + "rev": "5da8c00313b4434f00aed6b4c94cd3b207bafdc5", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index f2e6a984c..1ee1a20d8 100644 --- a/flake.nix +++ b/flake.nix @@ -49,6 +49,7 @@ # Specify the rust-src path (many editors rely on this) RUST_SRC_PATH = "${rustToolchain}/lib/rustlib/src/rust/library"; PLAYWRIGHT_BROWSERS_PATH = "${pkgs.playwright-driver.browsers}"; + PLAYWRIGHT_SKIP_VALIDATE_HOST_REQUIREMENTS = true; }; }); } diff --git a/web/src/shared/validators.ts b/web/src/shared/validators.ts index 0480c722c..44d3da6c5 100644 --- a/web/src/shared/validators.ts +++ b/web/src/shared/validators.ts @@ -79,7 +79,7 @@ export const validateIPv4 = (ip: string, allowMask = false): boolean => { } } const ipv4Pattern = /^(\d{1,3}\.){3}\d{1,3}$/; - const ipv4WithPortPattern = /^(\d{1,3}\.){3}\d{1,3}(:\d{1,5})?$/; + const ipv4WithPortPattern = /^(\d{1,3}\.){3}\d{1,3}:\d{1,5}$/; if (!ipv4Pattern.test(ip) && !ipv4WithPortPattern.test(ip)) { return false; } From 9d89d85b065681c85b31cf1262608c86caa36021 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Thu, 27 Nov 2025 13:12:18 +0100 Subject: [PATCH 06/17] Reverse gRPC communication --- Cargo.lock | 106 +- Cargo.toml | 3 +- crates/defguard_core/src/db/models/gateway.rs | 105 ++ crates/defguard_core/src/db/models/mod.rs | 1 + .../src/db/models/polling_token.rs | 11 +- .../defguard_core/src/db/models/wireguard.rs | 111 ++- .../src/enterprise/firewall/mod.rs | 2 +- .../defguard_core/src/grpc/gateway/handler.rs | 368 +++++++ crates/defguard_core/src/grpc/gateway/mod.rs | 902 ++++++++++-------- .../defguard_core/src/grpc/gateway/tests.rs | 87 ++ crates/defguard_core/src/grpc/mod.rs | 51 +- .../20251125072923_network_gateways.down.sql | 3 + .../20251125072923_network_gateways.up.sql | 20 + proto | 2 +- 14 files changed, 1264 insertions(+), 508 deletions(-) create mode 100644 crates/defguard_core/src/db/models/gateway.rs create mode 100644 crates/defguard_core/src/grpc/gateway/handler.rs create mode 100644 crates/defguard_core/src/grpc/gateway/tests.rs create mode 100644 migrations/20251125072923_network_gateways.down.sql create mode 100644 migrations/20251125072923_network_gateways.up.sql diff --git a/Cargo.lock b/Cargo.lock index a0722d94d..78a69433d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -574,9 +574,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.46" +version = "1.2.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97463e1064cb1b1c1384ad0a0b9c8abd0988e2a91f52606c80ef14aadb63e36" +checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" dependencies = [ "find-msvc-tools", "jobserver", @@ -669,9 +669,9 @@ checksum = "bba18ee93d577a8428902687bcc2b6b45a56b1981a1f6d779731c86cc4c5db18" [[package]] name = "clap" -version = "4.5.52" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa8120877db0e5c011242f96806ce3c94e0737ab8108532a76a3300a01db2ab8" +checksum = "c9e340e012a1bf4935f5282ed1436d1489548e8f72308207ea5df0e23d2d03f8" dependencies = [ "clap_builder", "clap_derive", @@ -679,9 +679,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.52" +version = "4.5.53" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02576b399397b659c26064fbc92a75fede9d18ffd5f80ca1cd74ddab167016e1" +checksum = "d76b5d13eaa18c901fd2f7fca939fefe3a0727a953561fefdf3b2922b8569d00" dependencies = [ "anstream", "anstyle", @@ -835,9 +835,9 @@ dependencies = [ [[package]] name = "crc" -version = "3.3.0" +version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9710d3b3739c2e349eb44fe848ad0b7c8cb1e42bd87ee49371df2f7acaf3e675" +checksum = "5eb8a2a1cd12ab0d987a5d5e825195d372001a4094a0376319d5a0ad71c1ba0d" dependencies = [ "crc-catalog", ] @@ -2001,7 +2001,7 @@ dependencies = [ "futures-core", "futures-sink", "http", - "indexmap 2.12.0", + "indexmap 2.12.1", "slab", "tokio", "tokio-util", @@ -2054,9 +2054,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419bdc4f6a9207fbeba6d11b604d481addf78ecd10c11ad51e76c2f6482748d" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" [[package]] name = "hashlink" @@ -2154,12 +2154,11 @@ dependencies = [ [[package]] name = "http" -version = "1.3.1" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" +checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" dependencies = [ "bytes", - "fnv", "itoa", ] @@ -2484,12 +2483,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.12.0" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717a8d2a5a929a1a2eb43a12812498ed141a0bcfb7e8f7844fbdbe4303bba9f" +checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2" dependencies = [ "equivalent", - "hashbrown 0.16.0", + "hashbrown 0.16.1", "serde", "serde_core", ] @@ -3594,9 +3593,9 @@ checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" [[package]] name = "pest" -version = "2.8.3" +version = "2.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "989e7521a040efde50c3ab6bbadafbe15ab6dc042686926be59ac35d74607df4" +checksum = "cbcfd20a6d4eeba40179f05735784ad32bdaef05ce8e8af05f180d45bb3e7e22" dependencies = [ "memchr", "ucd-trie", @@ -3604,9 +3603,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.8.3" +version = "2.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "187da9a3030dbafabbbfb20cb323b976dc7b7ce91fcd84f2f74d6e31d378e2de" +checksum = "51f72981ade67b1ca6adc26ec221be9f463f2b5839c7508998daa17c23d94d7f" dependencies = [ "pest", "pest_generator", @@ -3614,9 +3613,9 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.8.3" +version = "2.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49b401d98f5757ebe97a26085998d6c0eecec4995cad6ab7fc30ffdf4b052843" +checksum = "dee9efd8cdb50d719a80088b76f81aec7c41ed6d522ee750178f83883d271625" dependencies = [ "pest", "pest_meta", @@ -3627,9 +3626,9 @@ dependencies = [ [[package]] name = "pest_meta" -version = "2.8.3" +version = "2.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72f27a2cfee9f9039c4d86faa5af122a0ac3851441a34865b8a043b46be0065a" +checksum = "bf1d70880e76bdc13ba52eafa6239ce793d85c8e43896507e43dd8984ff05b82" dependencies = [ "pest", "sha2", @@ -3642,7 +3641,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "indexmap 2.12.0", + "indexmap 2.12.1", ] [[package]] @@ -4351,13 +4350,12 @@ dependencies = [ [[package]] name = "rust-ini" -version = "0.21.1" +version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e310ef0e1b6eeb79169a1171daf9abcb87a2e17c03bee2c4bb100b55c75409f" +checksum = "796e8d2b6696392a43bea58116b667fb4c29727dc5abd27d6acf338bb4f688c7" dependencies = [ "cfg-if", "ordered-multimap", - "trim-in-place", ] [[package]] @@ -4653,7 +4651,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2f2d7ff8a2140333718bb329f5c40fc5f0865b84c426183ce14c97d2ab8154f" dependencies = [ "form_urlencoded", - "indexmap 2.12.0", + "indexmap 2.12.1", "itoa", "ryu", "serde_core", @@ -4725,7 +4723,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.12.0", + "indexmap 2.12.1", "schemars 0.9.0", "schemars 1.1.0", "serde_core", @@ -4752,7 +4750,7 @@ version = "0.9.34+deprecated" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.12.1", "itoa", "ryu", "serde", @@ -5002,7 +5000,7 @@ dependencies = [ "futures-util", "hashbrown 0.15.5", "hashlink", - "indexmap 2.12.0", + "indexmap 2.12.1", "ipnetwork", "log", "memchr", @@ -5320,9 +5318,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.110" +version = "2.0.111" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a99801b5bd34ede4cf3fc688c5919368fea4e4814a4664359503e6015b280aea" +checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87" dependencies = [ "proc-macro2", "quote", @@ -5625,7 +5623,7 @@ version = "0.23.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.12.1", "toml_datetime", "toml_parser", "winnow", @@ -5744,7 +5742,7 @@ checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" dependencies = [ "futures-core", "futures-util", - "indexmap 2.12.0", + "indexmap 2.12.1", "pin-project-lite", "slab", "sync_wrapper", @@ -5757,9 +5755,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.6" +version = "0.6.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adc82fd73de2a9722ac5da747f12383d2bfdb93591ee6c58486e0097890f05f2" +checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" dependencies = [ "bitflags 2.10.0", "bytes", @@ -5809,9 +5807,9 @@ dependencies = [ [[package]] name = "tracing-attributes" -version = "0.1.30" +version = "0.1.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" dependencies = [ "proc-macro2", "quote", @@ -5820,9 +5818,9 @@ dependencies = [ [[package]] name = "tracing-core" -version = "0.1.34" +version = "0.1.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" dependencies = [ "once_cell", "valuable", @@ -5868,12 +5866,6 @@ dependencies = [ "syn", ] -[[package]] -name = "trim-in-place" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "343e926fc669bc8cde4fa3129ab681c63671bae288b1f1081ceee6d9d37904fc" - [[package]] name = "try-lock" version = "0.2.5" @@ -6024,7 +6016,7 @@ version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2fcc29c80c21c31608227e0912b2d7fddba57ad76b606890627ba8ee7964e993" dependencies = [ - "indexmap 2.12.0", + "indexmap 2.12.1", "serde", "serde_json", "utoipa-gen", @@ -6718,9 +6710,9 @@ checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" [[package]] name = "winnow" -version = "0.7.13" +version = "0.7.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21a0236b59786fed61e2a80582dd500fe61f18b5dca67a4a067d0bc9039339cf" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" dependencies = [ "memchr", ] @@ -6809,18 +6801,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.27" +version = "0.8.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +checksum = "4ea879c944afe8a2b25fef16bb4ba234f47c694565e97383b36f3a878219065c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.27" +version = "0.8.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +checksum = "cf955aa904d6040f70dc8e9384444cb1030aed272ba3cb09bbc4ab9e7c1f34f5" dependencies = [ "proc-macro2", "quote", @@ -6910,7 +6902,7 @@ dependencies = [ "arbitrary", "crc32fast", "flate2", - "indexmap 2.12.0", + "indexmap 2.12.1", "memchr", "zopfli", ] diff --git a/Cargo.toml b/Cargo.toml index 673c673cc..2cb6ad1bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,8 +61,7 @@ pulldown-cmark = "0.13" rand = "0.8" reqwest = { version = "0.12", features = ["json"] } rsa = "0.9" -# 0.21.2 causes config parsing errors -rust-ini = "=0.21.1" +rust-ini = "0.21" semver = { version = "1.0", features = ["serde"] } secrecy = { version = "0.10", features = ["serde"] } serde = { version = "1.0", features = ["derive"] } diff --git a/crates/defguard_core/src/db/models/gateway.rs b/crates/defguard_core/src/db/models/gateway.rs new file mode 100644 index 000000000..5d6221f6b --- /dev/null +++ b/crates/defguard_core/src/db/models/gateway.rs @@ -0,0 +1,105 @@ +use std::fmt; + +use chrono::{NaiveDateTime, Utc}; +use model_derive::Model; +use sqlx::{PgExecutor, query, query_as}; + +use defguard_common::db::{Id, NoId}; + +#[derive(Clone, Debug, Deserialize, Model, PartialEq, Serialize)] +pub(crate) struct Gateway { + pub id: I, + pub network_id: Id, + pub url: String, + pub hostname: Option, + pub connected_at: Option, + pub disconnected_at: Option, +} + +impl Gateway { + #[must_use] + pub(crate) fn new>(network_id: Id, url: S) -> Self { + Self { + id: NoId, + network_id, + url: url.into(), + hostname: None, + connected_at: None, + disconnected_at: None, + } + } +} + +impl Gateway { + pub(crate) async fn find_by_network_id<'e, E>( + executor: E, + network_id: Id, + ) -> Result, sqlx::Error> + where + E: PgExecutor<'e>, + { + query_as!( + Self, + "SELECT * FROM gateway WHERE network_id = $1 ORDER BY id", + network_id + ) + .fetch_all(executor) + .await + } + + /// Update `hostname` and set `connected_at` to the current time and save it to the database. + pub(crate) async fn touch_connected<'e, E>( + &mut self, + executor: E, + hostname: String, + ) -> Result<(), sqlx::Error> + where + E: PgExecutor<'e>, + { + self.hostname = Some(hostname); + self.connected_at = Some(Utc::now().naive_utc()); + query!( + "UPDATE gateway SET hostname = $2, connected_at = $3 WHERE id = $1", + self.id, + self.hostname, + self.connected_at + ) + .execute(executor) + .await?; + + Ok(()) + } + + /// Set `disconnected_at` to the current time and save it to the database. + pub(crate) async fn touch_disconnected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> + where + E: PgExecutor<'e>, + { + self.disconnected_at = Some(Utc::now().naive_utc()); + query!( + "UPDATE gateway SET disconnected_at = $2 WHERE id = $1", + self.id, + self.disconnected_at + ) + .execute(executor) + .await?; + + Ok(()) + } + + pub(crate) fn is_connected(&self) -> bool { + if let (Some(connected_at), Some(disconnected_at)) = + (self.connected_at, self.disconnected_at) + { + disconnected_at <= connected_at + } else { + self.connected_at.is_some() + } + } +} + +impl fmt::Display for Gateway { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Gateway(ID {}; URL {})", self.id, self.url) + } +} diff --git a/crates/defguard_core/src/db/models/mod.rs b/crates/defguard_core/src/db/models/mod.rs index df2faac41..0f9061f45 100644 --- a/crates/defguard_core/src/db/models/mod.rs +++ b/crates/defguard_core/src/db/models/mod.rs @@ -1,6 +1,7 @@ pub mod activity_log; pub mod device; pub mod enrollment; +pub mod gateway; pub mod group; pub mod oauth2authorizedapp; pub mod oauth2client; diff --git a/crates/defguard_core/src/db/models/polling_token.rs b/crates/defguard_core/src/db/models/polling_token.rs index b4d911936..6c2353587 100644 --- a/crates/defguard_core/src/db/models/polling_token.rs +++ b/crates/defguard_core/src/db/models/polling_token.rs @@ -4,7 +4,7 @@ use defguard_common::{ random::gen_alphanumeric, }; use model_derive::Model; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, query_as}; +use sqlx::{PgExecutor, query_as}; // Token used for polling requests. #[derive(Clone, Debug, Model)] @@ -28,18 +28,21 @@ impl PollingToken { } impl PollingToken { - pub async fn find(pool: &PgPool, token: &str) -> Result, SqlxError> { + pub async fn find<'e, E>(executor: E, token: &str) -> Result, sqlx::Error> + where + E: PgExecutor<'e>, + { query_as!( Self, "SELECT id, token, device_id, created_at \ FROM pollingtoken WHERE token = $1", token ) - .fetch_optional(pool) + .fetch_optional(executor) .await } - pub async fn delete_for_device_id<'e, E>(executor: E, device_id: Id) -> Result<(), SqlxError> + pub async fn delete_for_device_id<'e, E>(executor: E, device_id: Id) -> Result<(), sqlx::Error> where E: PgExecutor<'e>, { diff --git a/crates/defguard_core/src/db/models/wireguard.rs b/crates/defguard_core/src/db/models/wireguard.rs index 33c26e498..2e559dcbf 100644 --- a/crates/defguard_core/src/db/models/wireguard.rs +++ b/crates/defguard_core/src/db/models/wireguard.rs @@ -23,8 +23,8 @@ use ipnetwork::{IpNetwork, IpNetworkError, NetworkSize}; use model_derive::Model; use rand::rngs::OsRng; use sqlx::{ - Error as SqlxError, FromRow, PgConnection, PgExecutor, PgPool, Type, - postgres::types::PgInterval, query_as, query_scalar, + FromRow, PgConnection, PgExecutor, PgPool, Type, + postgres::types::PgInterval, query, query_as, query_scalar, }; use thiserror::Error; use tokio::sync::broadcast::Sender; @@ -934,13 +934,15 @@ impl WireguardNetwork { &self, conn: &PgPool, device_id: Id, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { // Find a first handshake gap longer than WIREGUARD_MAX_HANDSHAKE. // We assume that this gap indicates a time when the device was not connected. // So, the handshake after this gap is the moment the last connection was established. - // If no such gap is found, the device may be connected from the beginning, return the first handshake in this case. + // If no such gap is found, the device may be connected from the beginning, return the first + // handshake in this case. let connected_at = query_scalar!( - "WITH stats AS (SELECT * FROM wireguard_peer_stats_view WHERE device_id = $1 AND network = $2) \ + "WITH stats AS \ + (SELECT * FROM wireguard_peer_stats_view WHERE device_id = $1 AND network = $2) \ SELECT \ COALESCE( \ ( \ @@ -964,6 +966,85 @@ impl WireguardNetwork { Ok(connected_at) } + /// Get a list of all allowed peers + /// + /// Each device is marked as allowed or not allowed in a given network, + /// which enables enforcing peer disconnect in MFA-protected networks. + /// + /// If the location is a service location, only returns peers if enterprise features are enabled. + pub async fn get_peers<'e, E>(&self, executor: E) -> Result, sqlx::Error> + where + E: PgExecutor<'e>, + { + debug!("Fetching all peers for network {}", self.id); + + if self.should_prevent_service_location_usage() { + warn!( + "Tried to use service location {} with disabled enterprise features. No clients will be allowed to connect.", + self.name + ); + return Ok(Vec::new()); + } + + let rows = query!( + "SELECT d.wireguard_pubkey pubkey, preshared_key, \ + -- TODO possible to not use ARRAY-unnest here? + ARRAY( + SELECT host(ip) + FROM unnest(wnd.wireguard_ips) AS ip + ) \"allowed_ips!: Vec\" \ + FROM wireguard_network_device wnd \ + JOIN device d ON wnd.device_id = d.id \ + JOIN \"user\" u ON d.user_id = u.id \ + WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ + AND d.configured = true \ + AND u.is_active = true \ + ORDER BY d.id ASC", + self.id, + self.mfa_enabled() + ) + .fetch_all(executor) + .await?; + + // keepalive has to be added manually because Postgres + // doesn't support unsigned integers + let result = rows + .into_iter() + .map(|row| Peer { + pubkey: row.pubkey, + allowed_ips: row.allowed_ips, + // Don't send preshared key if MFA is not enabled, it can't be used and may + // cause issues with clients connecting if they expect no preshared key + // e.g. when you disable MFA on a location + preshared_key: if self.mfa_enabled() { + row.preshared_key + } else { + None + }, + keepalive_interval: Some(self.keepalive_interval as u32), + }) + .collect(); + + Ok(result) + } + + /// Update `connected_at` to the current time and save it to the database. + pub(crate) async fn touch_connected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error> + where + E: PgExecutor<'e>, + { + self.connected_at = Some(Utc::now().naive_utc()); + query!( + "UPDATE wireguard_network SET connected_at = $2 WHERE name = $1", + self.name, + self.connected_at + ) + .execute(executor) + .await?; + + Ok(()) + } + /// Retrieves stats for specified devices pub(crate) async fn device_stats( &self, @@ -971,7 +1052,7 @@ impl WireguardNetwork { devices: &[Device], from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { if devices.is_empty() { return Ok(Vec::new()); } @@ -1036,7 +1117,7 @@ impl WireguardNetwork { from: &NaiveDateTime, aggregation: &DateTimeAggregation, device_type: DeviceType, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { let oldest_handshake = (Utc::now() - WIREGUARD_MAX_HANDSHAKE).naive_utc(); // Retrieve connected devices from database let devices = query_as!( @@ -1062,7 +1143,7 @@ impl WireguardNetwork { conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { let mut user_map: HashMap> = HashMap::new(); // Retrieve data series for all active devices and assign them to users let device_stats = self @@ -1076,7 +1157,7 @@ impl WireguardNetwork { for u in user_map { let user = User::find_by_id(conn, u.0) .await? - .ok_or(SqlxError::RowNotFound)?; + .ok_or(sqlx::Error::RowNotFound)?; stats.push(WireguardUserStatsRow { user: UserInfo::from_user(conn, &user).await?, devices: u.1.clone(), @@ -1091,7 +1172,7 @@ impl WireguardNetwork { &self, conn: &PgPool, from: &NaiveDateTime, - ) -> Result { + ) -> Result { let activity_stats = query_as!( WireguardNetworkActivityStats, "SELECT \ @@ -1115,7 +1196,7 @@ impl WireguardNetwork { async fn current_activity( &self, conn: &PgPool, - ) -> Result { + ) -> Result { let from = (Utc::now() - WIREGUARD_MAX_HANDSHAKE).naive_utc(); let activity_stats = query_as!( WireguardNetworkActivityStats, @@ -1143,7 +1224,7 @@ impl WireguardNetwork { conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result, SqlxError> { + ) -> Result, sqlx::Error> { let stats = query_as!( WireguardStatsRow, "SELECT \ @@ -1171,7 +1252,7 @@ impl WireguardNetwork { conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, - ) -> Result { + ) -> Result { let total_activity = self.total_activity(conn, from).await?; let current_activity = self.current_activity(conn).await?; let transfer_series = self.transfer_series(conn, from, aggregation).await?; @@ -1192,7 +1273,7 @@ impl WireguardNetwork { &self, executor: E, device_type: DeviceType, - ) -> Result>, SqlxError> + ) -> Result>, sqlx::Error> where E: PgExecutor<'e>, { @@ -1432,7 +1513,7 @@ pub(crate) async fn networks_stats( conn: &PgPool, from: &NaiveDateTime, aggregation: &DateTimeAggregation, -) -> Result { +) -> Result { let total_activity = query_as!( WireguardNetworkActivityStats, "SELECT \ diff --git a/crates/defguard_core/src/enterprise/firewall/mod.rs b/crates/defguard_core/src/enterprise/firewall/mod.rs index 5e2b7e8d9..44ed70b0a 100644 --- a/crates/defguard_core/src/enterprise/firewall/mod.rs +++ b/crates/defguard_core/src/enterprise/firewall/mod.rs @@ -896,7 +896,7 @@ impl WireguardNetwork { Ok(rules_info) } - /// Prepares firewall configuration for a gateway based on location config and ACLs + /// Prepares firewall configuration for Gateway based on location config and ACLs. /// Returns `None` if firewall management is disabled for a given location. pub async fn try_get_firewall_config( &self, diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs new file mode 100644 index 000000000..66b9d07aa --- /dev/null +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -0,0 +1,368 @@ +use std::{ + str::FromStr, + sync::atomic::{AtomicU64, Ordering}, +}; + +use defguard_common::{auth::claims::Claims, db::Id}; +use defguard_mail::Mail; +use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; +use sqlx::PgPool; +use tokio::{ + sync::mpsc::{self, Sender, UnboundedSender}, + time::sleep, +}; +use tokio_stream::wrappers::UnboundedReceiverStream; +use tonic::{ + Code, Status, + transport::{ClientTlsConfig, Endpoint}, +}; + +use crate::{ + ClaimsType, + db::{ + Device, GatewayEvent, WireguardNetwork, + models::{gateway::Gateway, wireguard_peer_stats::WireguardPeerStats}, + }, + grpc::TEN_SECS, + handlers::mail::send_gateway_disconnected_email, +}; + +/// One instance per connected gateway. +pub(super) struct GatewayHandler { + endpoint: Endpoint, + gateway: Gateway, + message_id: AtomicU64, + pool: PgPool, + events_tx: Sender, + mail_tx: UnboundedSender, +} + +impl GatewayHandler { + pub(super) fn new( + gateway: Gateway, + tls_config: Option, + pool: PgPool, + events_tx: Sender, + mail_tx: UnboundedSender, + ) -> Result { + let endpoint = Endpoint::from_shared(gateway.url.to_string())? + .http2_keep_alive_interval(TEN_SECS) + .tcp_keepalive(Some(TEN_SECS)) + .keep_alive_while_idle(true); + let endpoint = if let Some(tls) = tls_config { + endpoint.tls_config(tls)? + } else { + endpoint + }; + + Ok(Self { + endpoint, + gateway, + message_id: AtomicU64::new(0), + pool, + events_tx, + mail_tx, + }) + } + + /// Send network and VPN configuration to Gateway. + async fn send_configuration(&self, tx: &UnboundedSender) -> Result<(), Status> { + debug!("Sending configuration to Gateway"); + let network_id = self.gateway.network_id; + // let hostname = Self::get_gateway_hostname(request.metadata())?; + + let mut conn = self.pool.acquire().await.map_err(|err| { + error!("Failed to acquire DB connection: {err}"); + Status::new( + Code::Internal, + "Failed to acquire database connection".to_string(), + ) + })?; + + let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) + .await + .map_err(|err| { + error!("Network {network_id} not found"); + Status::new(Code::Internal, format!("Failed to retrieve network: {err}")) + })? + .ok_or_else(|| { + Status::new( + Code::Internal, + format!("Network with id {network_id} not found"), + ) + })?; + + debug!( + "Sending configuration to {}, network {network}", + self.gateway + ); + if let Err(err) = network.touch_connected(&mut *conn).await { + error!( + "Failed to update connection time for network {network_id} in the database, \ + status: {err}" + ); + } + + let peers = network.get_peers(&self.pool).await.map_err(|error| { + error!("Failed to fetch peers from the database for network {network_id}: {error}",); + Status::new( + Code::Internal, + format!("Failed to retrieve peers from the database for network: {network_id}"), + ) + })?; + + let maybe_firewall_config = + network + .try_get_firewall_config(&mut *conn) + .await + .map_err(|err| { + error!("Failed to generate firewall config for network {network_id}: {err}"); + Status::new( + Code::Internal, + format!("Failed to generate firewall config for network: {network_id}"), + ) + })?; + let payload = Some(core_response::Payload::Config(super::gen_config( + &network, + peers, + maybe_firewall_config, + ))); + let id = self.message_id.fetch_add(1, Ordering::Relaxed); + let req = CoreResponse { id, payload }; + match tx.send(req) { + Ok(()) => { + info!("Configuration sent to {}, network {network}", self.gateway); + Ok(()) + } + Err(err) => { + error!("Failed to send configuration sent to {}", self.gateway); + Err(Status::new( + Code::Internal, + format!("Configuration not sent to {}, error {err}", self.gateway), + )) + } + } + } + + /// Send gateway disconnected notification. + /// Sends notification only if last notification time is bigger than specified in config. + async fn send_disconnect_notification(&self) { + debug!("Sending gateway disconnect email notification"); + let hostname = self.gateway.hostname.clone(); + let mail_tx = self.mail_tx.clone(); + let pool = self.pool.clone(); + let url = self.gateway.url.clone(); + + let Ok(Some(network)) = + WireguardNetwork::find_by_id(&self.pool, self.gateway.network_id).await + else { + error!( + "Failed to fetch network ID {} from database", + self.gateway.network_id + ); + return; + }; + + // Send email only if disconnection time is before the connection time. + let send_email = if let (Some(connected_at), Some(disconnected_at)) = + (self.gateway.connected_at, self.gateway.disconnected_at) + { + disconnected_at <= connected_at + } else { + true + }; + if send_email { + // FIXME: Try to get rid of spawn and use something like block_on + // To return result instead of logging + tokio::spawn(async move { + if let Err(err) = + send_gateway_disconnected_email(hostname, network.name, &url, &mail_tx, &pool) + .await + { + error!("Failed to send gateway disconnect notification: {err}"); + } else { + info!("Email notification sent about gateway being disconnected"); + } + }); + } else { + info!( + "{} disconnected. Email notification not sent.", + self.gateway + ); + }; + } + + /// Connect to Gateway and handle its messages through gRPC. + pub(super) async fn handle_connection(&mut self) -> ! { + let uri = self.endpoint.uri(); + loop { + #[cfg(not(test))] + let channel = self.endpoint.connect_lazy(); + #[cfg(test)] + let channel = self.endpoint.connect_with_connector_lazy(tower::service_fn( + |_: tonic::transport::Uri| async { + Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( + tokio::net::UnixStream::connect(super::tests::TONIC_SOCKET).await?, + )) + }, + )); + + debug!("Connecting to Gateway {uri}"); + let mut client = gateway_client::GatewayClient::new(channel); + let (tx, rx) = mpsc::unbounded_channel(); + let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { + Ok(response) => response, + Err(err) => { + error!("Failed to connect to gateway {uri}, retrying: {err}"); + sleep(TEN_SECS).await; + continue; + } + }; + + info!("Connected to Defguard Gateway {uri}"); + let mut resp_stream = response.into_inner(); + let mut config_sent = false; + + 'message: loop { + match resp_stream.message().await { + Ok(None) => { + info!("stream was closed by the sender"); + break 'message; + } + Ok(Some(received)) => { + info!("Received message from gateway."); + debug!("Message from Gateway {uri}"); + match received.payload { + Some(core_request::Payload::ConfigRequest(config_request)) => { + if config_sent { + warn!( + "Ignoring repeated configuration request from {}", + self.gateway + ); + continue; + } + // Validate authorization token. + if let Ok(claims) = Claims::from_jwt( + ClaimsType::Gateway, + &config_request.auth_token, + ) { + if let Ok(client_id) = Id::from_str(&claims.client_id) { + if client_id == self.gateway.network_id { + debug!( + "Authorization token is correct for {}", + self.gateway + ); + } else { + warn!( + "Authorization token received from {uri} has \ + `client_id` for a different network" + ); + continue; + } + } else { + warn!( + "Authorization token received from {uri} has incorrect \ + `client_id`" + ); + continue; + } + } else { + warn!("Invalid authorization token received from {uri}"); + continue; + } + + // Send network configuration to Gateway. + match self.send_configuration(&tx).await { + Ok(()) => { + info!("Sent configuration to {}", self.gateway); + config_sent = true; + let _ = self + .gateway + .touch_connected(&self.pool, config_request.hostname) + .await; + } + Err(err) => { + error!( + "Failed to send configuration to {}: {err}", + self.gateway + ); + } + } + + // Start observing configuration changes. + let Ok(Some(network)) = WireguardNetwork::find_by_id( + &self.pool, + self.gateway.network_id, + ) + .await + else { + error!( + "Failed to fetch network ID {} from the database", + self.gateway.network_id + ); + continue; + }; + // tokio::spawn(super::handle_events( + // network, + // tx.clone(), + // self.events_tx.subscribe(), + // )); + } + Some(core_request::Payload::PeerStats(peer_stats)) => { + if !config_sent { + warn!( + "Ignoring peer statistics from {} because it didn't \ + authorize itself", + self.gateway + ); + continue; + } + + // let public_key = peer_stats.public_key.clone(); + // let mut stats = WireguardPeerStats::from_peer_stats( + // peer_stats, + // self.gateway.network_id, + + // ); + // // Get device by public key and fill in stats.device_id + // match Device::find_by_pubkey(&self.pool, &public_key).await { + // Ok(Some(device)) => { + // stats.device_id = device.id; + // match stats.save(&self.pool).await { + // Ok(_) => { + // info!("Saved WireGuard peer stats to database.") + // } + // Err(err) => error!( + // "Failed to save WireGuard peer stats to database: \ + // {err}" + // ), + // } + // } + // Ok(None) => { + // error!("Device with public key {public_key} not found"); + // } + // Err(err) => { + // error!( + // "Failed to retrieve device with public key \ + // {public_key}: {err}", + // ); + // } + // }; + } + None => (), + }; + } + Err(err) => { + error!("Disconnected from gateway at {uri}, error: {err}"); + // Important: call this funtion before setting disconnection time. + self.send_disconnect_notification().await; + let _ = self.gateway.touch_disconnected(&self.pool).await; + debug!("Waiting 10s to re-establish the connection"); + sleep(TEN_SECS).await; + break 'message; + } + } + } + } + } +} diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index ff119fc0f..b94db32f1 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,7 +1,7 @@ use std::{ net::{IpAddr, SocketAddr}, pin::Pin, - sync::{Arc, Mutex}, + sync::{Arc, Mutex, MutexGuard}, task::{Context, Poll}, }; @@ -12,13 +12,13 @@ use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, gateway::{ - Configuration, ConfigurationRequest, Peer, PeerStats, StatsUpdate, Update, - gateway_service_server, stats_update, update, + Configuration, ConfigurationRequest, CoreResponse, Peer, PeerStats, Update, UpdateType, + core_response, update, }, }; use defguard_version::version_info_from_metadata; use semver::Version; -use sqlx::{Error as SqlxError, PgExecutor, PgPool, query}; +use sqlx::PgPool; use thiserror::Error; use tokio::{ sync::{ @@ -41,8 +41,11 @@ use crate::{ }; pub mod client_state; +pub(crate) mod handler; pub mod map; pub(crate) mod state; +#[cfg(test)] +mod tests; const PEER_DISCONNECT_INTERVAL: u64 = 60; @@ -90,70 +93,6 @@ pub struct GatewayServer { grpc_event_tx: UnboundedSender, } -impl WireguardNetwork { - /// Get a list of all allowed peers - /// - /// Each device is marked as allowed or not allowed in a given network, - /// which enables enforcing peer disconnect in MFA-protected networks. - /// - /// If the location is a service location, only returns peers if enterprise features are enabled. - pub async fn get_peers<'e, E>(&self, executor: E) -> Result, SqlxError> - where - E: PgExecutor<'e>, - { - debug!("Fetching all peers for network {}", self.id); - - if self.should_prevent_service_location_usage() { - warn!( - "Tried to use service location {} with disabled enterprise features. No clients will be allowed to connect.", - self.name - ); - return Ok(Vec::new()); - } - - let rows = query!( - "SELECT d.wireguard_pubkey pubkey, preshared_key, \ - -- TODO possible to not use ARRAY-unnest here? - ARRAY( - SELECT host(ip) - FROM unnest(wnd.wireguard_ips) AS ip - ) \"allowed_ips!: Vec\" \ - FROM wireguard_network_device wnd \ - JOIN device d ON wnd.device_id = d.id \ - JOIN \"user\" u ON d.user_id = u.id \ - WHERE wireguard_network_id = $1 AND (is_authorized = true OR NOT $2) \ - AND d.configured = true \ - AND u.is_active = true \ - ORDER BY d.id ASC", - self.id, - self.mfa_enabled() - ) - .fetch_all(executor) - .await?; - - // keepalive has to be added manually because Postgres - // doesn't support unsigned integers - let result = rows - .into_iter() - .map(|row| Peer { - pubkey: row.pubkey, - allowed_ips: row.allowed_ips, - // Don't send preshared key if MFA is not enabled, it can't be used and may - // cause issues with clients connecting if they expect no preshared key - // e.g. when you disable MFA on a location - preshared_key: if self.mfa_enabled() { - row.preshared_key - } else { - None - }, - keepalive_interval: Some(self.keepalive_interval as u32), - }) - .collect(); - - Ok(result) - } -} - /// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. struct GatewayMetadata { network_id: Id, @@ -224,9 +163,7 @@ impl GatewayServer { } } - pub fn get_client_state_guard( - &self, - ) -> Result, GatewayServerError> { + pub fn get_client_state_guard(&self) -> Result, GatewayServerError> { let client_state = self .client_state .lock() @@ -354,6 +291,167 @@ impl WireguardPeerStats { } } +/* + +/// Process received Gateway events +/// +/// Main gRPC server uses a shared channel for broadcasting all Gateway events, +/// so the handler must determine if an event is relevant for the network being serviced. +async fn handle_events( + mut current_network: WireguardNetwork, + tx: UnboundedSender, + mut events_rx: Receiver, +) { + info!("Starting update stream network {current_network}"); + while let Some(event) = events_rx.recv().await { + debug!("Received networking state update event: {event:?}"); + let (update_type, update) = match event { + GatewayEvent::NetworkCreated(network, _fixme) => { + if network.id != current_network.id { + continue; + } + ( + UpdateType::Create, + update::Update::Network(Configuration { + name: network.name.clone(), + prvkey: network.prvkey.clone(), + addresses: network.address.to_string(), + port: network.port as u32, + peers: Vec::new(), + }), + ) + } + GatewayEvent::NetworkModified(network, peers, _fixme) => { + if network.id != current_network.id { + continue; + } + // update stored network data + current_network = network.clone(); + ( + UpdateType::Modify, + update::Update::Network(Configuration { + name: network.name, + prvkey: network.prvkey, + addresses: network.address.to_string(), + port: network.port as u32, + peers, + }), + ) + } + GatewayEvent::NetworkDeleted(network_id, network_name) => { + if network_id != current_network.id { + continue; + } + ( + UpdateType::Delete, + update::Update::Network(Configuration { + name: network_name.to_string(), + prvkey: String::new(), + addresses: Vec::new(), + port: 0, + peers: Vec::new(), + firewall_config: None, + }), + ) + } + GatewayEvent::DeviceCreated(device) => { + // check if a peer has to be added in the current network + match device + .network_info + .iter() + .find(|info| info.network_id == current_network.id) + { + Some(network_info) => { + if current_network.mfa_enabled && !network_info.is_authorized { + debug!( + "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", + device.device.name, current_network.name + ); + continue; + }; + let peer = Peer { + pubkey: device.device.wireguard_pubkey, + allowed_ips: vec![network_info.device_wireguard_ip.to_string()], + preshared_key: network_info.preshared_key.clone(), + keepalive_interval: Some(current_network.keepalive_interval as u32), + }; + (UpdateType::Create, update::Update::Peer(peer)) + } + None => continue, + } + } + GatewayEvent::DeviceModified(device) => { + // check if a peer has to be updated in the current network + match device + .network_info + .iter() + .find(|info| info.network_id == current_network.id) + { + Some(network_info) => { + if current_network.mfa_enabled && !network_info.is_authorized { + debug!( + "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", + device.device.name, current_network.name + ); + continue; + }; + let peer = Peer { + pubkey: device.device.wireguard_pubkey, + allowed_ips: vec![network_info.device_wireguard_ip.to_string()], + preshared_key: network_info.preshared_key.clone(), + keepalive_interval: Some(current_network.keepalive_interval as u32), + }; + (UpdateType::Modify, update::Update::Peer(peer)) + } + None => continue, + } + } + GatewayEvent::DeviceDeleted(device) => { + // check if a peer has to be updated in the current network + match device + .network_info + .iter() + .find(|info| info.network_id == current_network.id) + { + Some(_) => ( + UpdateType::Delete, + update::Update::Peer(Peer { + pubkey: device.device.wireguard_pubkey, + allowed_ips: Vec::new(), + preshared_key: None, + keepalive_interval: None, + }), + ), + None => continue, + } + } + GatewayEvent::FirewallConfigChanged(_fixme, _) => (), + GatewayEvent::FirewallDisabled(_id) => (), + }; + + let req = CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { + update_type: update_type as i32, + update: Some(update), + })), + }; + if let Err(err) = tx.send(req) { + error!( + "Failed to send network update, network {current_network}, update type: {}, error: \ + {err}", + update_type.as_str_name() + ); + break; + } + debug!( + "Network update sent for network {current_network}, update type: {}", + update_type.as_str_name() + ); + } +} +*/ + /// Helper struct for handling gateway events struct GatewayUpdatesHandler { network_id: Id, @@ -751,334 +849,334 @@ impl Drop for GatewayUpdatesStream { } } -#[tonic::async_trait] -impl gateway_service_server::GatewayService for GatewayServer { - type UpdatesStream = GatewayUpdatesStream; - - /// Retrieve stats from gateway and save it to database - async fn stats( - &self, - request: Request>, - ) -> Result, Status> { - let GatewayMetadata { - network_id, - hostname, - .. - } = Self::extract_metadata(request.metadata())?; - let mut stream = request.into_inner(); - let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL)); - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_stats", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - loop { - // Wait for a message or update client map at least once a mninute, if no messages are - // received. - let stats_update = tokio::select! { - message = stream.message() => { - match message? { - Some(update) => update, - None => break, // Stream ended - } - } - _ = disconnect_timer.tick() => { - debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. \ - Updating disconnected VPN clients"); - // fetch location to get current peer disconnect threshold - let location = self.fetch_location_from_db(network_id).await?; - - // perform client state operations in a dedicated block to drop mutex guard - let disconnected_clients = { - // acquire lock on client state map - let mut client_map = self.get_client_state_guard()?; - - // disconnect inactive clients - client_map.disconnect_inactive_vpn_clients_for_location(&location - )? - }; - - // emit client disconnect events - for (device, context) in disconnected_clients { - self.emit_event(GrpcEvent::ClientDisconnected { - context, - location: location.clone(), - device, - })?; - }; - continue; - } - }; - - debug!("Received stats message: {stats_update:?}"); - let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else { - debug!("Received stats message is empty, skipping."); - continue; - }; - let public_key = peer_stats.public_key.clone(); - - // fetch device from DB - // TODO: fetch only when device has changed and use client state otherwise - let device = match self.fetch_device_from_db(&public_key).await? { - Some(device) => device, - None => { - warn!( - "Received stats update for a device which does not exist: {public_key}, skipping." - ); - continue; - } - }; - - // copy device ID for easier reference later - let device_id = device.id; - - // fetch user and location from DB for activity log - // TODO: cache usernames since they don't change - let user = self.fetch_user_from_db(device.user_id, &public_key).await?; - let location = self.fetch_location_from_db(network_id).await?; - - // convert stats to DB storage format - let stats = WireguardPeerStats::from_peer_stats(peer_stats, network_id, device_id); - - // only perform client state update if stats include an endpoint IP - // otherwise a peer was added to the gateway interface - // but has not connected yet - if let Some(endpoint) = &stats.endpoint { - // parse client endpoint IP - let socket_addr: SocketAddr = endpoint.clone().parse().map_err(|err| { - error!("Failed to parse VPN client endpoint: {err}"); - Status::new( - Code::Internal, - format!("Failed to parse VPN client endpoint: {err}"), - ) - })?; - - // perform client state operations in a dedicated block to drop mutex guard - let disconnected_clients = { - // acquire lock on client state map - let mut client_map = self.get_client_state_guard()?; - - // update connected clients map - match client_map.get_vpn_client(network_id, &public_key) { - Some(client_state) => { - // update connected client state - client_state.update_client_state( - device, - socket_addr, - stats.latest_handshake, - stats.upload, - stats.download, - ); - } - None => { - // don't mark inactive peers as connected - if (Utc::now().naive_utc() - stats.latest_handshake) - < TimeDelta::seconds(location.peer_disconnect_threshold.into()) - { - // mark new VPN client as connected - client_map.connect_vpn_client( - network_id, - &hostname, - &public_key, - &device, - &user, - socket_addr, - &stats, - )?; - - // emit connection event - let context = GrpcRequestContext::new( - user.id, - user.username.clone(), - socket_addr.ip(), - device.id, - device.name.clone(), - location.clone(), - ); - self.emit_event(GrpcEvent::ClientConnected { - context, - location: location.clone(), - device: device.clone(), - })?; - } - } - } - - // disconnect inactive clients - client_map.disconnect_inactive_vpn_clients_for_location(&location)? - }; - - // emit client disconnect events - for (device, context) in disconnected_clients { - self.emit_event(GrpcEvent::ClientDisconnected { - context, - location: location.clone(), - device, - })?; - } - } - - // Save stats to db - let stats = match stats.save(&self.pool).await { - Ok(stats) => stats, - Err(err) => { - error!("Saving WireGuard peer stats to db failed: {err}"); - return Err(Status::new( - Code::Internal, - format!("Saving WireGuard peer stats to db failed: {err}"), - )); - } - }; - info!("Saved WireGuard peer stats to db."); - debug!("WireGuard peer stats: {stats:?}"); - } - - Ok(Response::new(())) - } - - async fn config( - &self, - request: Request, - ) -> Result, Status> { - debug!("Sending configuration to gateway client."); - let GatewayMetadata { - network_id, - hostname, - version, - .. - // info, - } = Self::extract_metadata(request.metadata())?; - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_config", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - - let mut conn = self.pool.acquire().await.map_err(|e| { - error!("Failed to acquire DB connection: {e}"); - Status::new( - Code::Internal, - "Failed to acquire DB connection".to_string(), - ) - })?; - - let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) - .await - .map_err(|e| { - error!("Network {network_id} not found"); - Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) - })? - .ok_or_else(|| { - Status::new( - Code::Internal, - format!("Network with id {network_id} not found"), - ) - })?; - - debug!("Sending configuration to gateway client, network {network}."); - - // store connected gateway in memory - { - let mut state = self.gateway_state.lock().unwrap(); - state.add_gateway( - network_id, - &network.name, - hostname, - request.into_inner().name, - self.mail_tx.clone(), - version, - ); - } - - network.connected_at = Some(Utc::now().naive_utc()); - if let Err(err) = network.save(&mut *conn).await { - error!("Failed to save updated network {network_id} in the database, status: {err}"); - } - - let peers = network.get_peers(&mut *conn).await.map_err(|error| { - error!("Failed to fetch peers from the database for network {network_id}: {error}",); - Status::new( - Code::Internal, - format!("Failed to retrieve peers from the database for network: {network_id}"), - ) - })?; - let maybe_firewall_config = - network - .try_get_firewall_config(&mut conn) - .await - .map_err(|err| { - error!("Failed to generate firewall config for network {network_id}: {err}"); - Status::new( - Code::Internal, - format!("Failed to generate firewall config for network: {network_id}"), - ) - })?; - - info!("Configuration sent to gateway client, network {network}."); - - Ok(Response::new(gen_config( - &network, - peers, - maybe_firewall_config, - ))) - } - - async fn updates(&self, request: Request<()>) -> Result, Status> { - let GatewayMetadata { - network_id, - hostname, - .. - // info, - } = Self::extract_metadata(request.metadata())?; - // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. - // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, - // version = version.to_string(), info); - // let _guard = span.enter(); - - let Some(network) = WireguardNetwork::find_by_id(&self.pool, network_id) - .await - .map_err(|_| { - error!("Failed to fetch network {network_id} from the database"); - Status::new( - Code::Internal, - format!("Failed to retrieve network {network_id} from the database"), - ) - })? - else { - return Err(Status::new( - Code::Internal, - format!("Network with id {network_id} not found"), - )); - }; - - info!("New client connected to updates stream: {hostname}, network {network}",); - - let (tx, rx) = mpsc::channel(4); - let events_rx = self.wireguard_tx.subscribe(); - let mut state = self.gateway_state.lock().unwrap(); - state - .connect_gateway(network_id, &hostname, &self.pool) - .map_err(|err| { - error!("Failed to connect gateway on network {network_id}: {err}"); - Status::new( - Code::Internal, - format!("Failed to connect gateway on network {network_id}"), - ) - })?; - - // clone here before moving into a closure - let gateway_hostname = hostname.clone(); - let handle = tokio::spawn(async move { - let mut update_handler = - GatewayUpdatesHandler::new(network_id, network, gateway_hostname, events_rx, tx); - update_handler.run().await; - }); - - Ok(Response::new(GatewayUpdatesStream::new( - handle, - rx, - network_id, - hostname, - Arc::clone(&self.gateway_state), - self.pool.clone(), - ))) - } -} +// #[tonic::async_trait] +// impl gateway_service_server::GatewayService for GatewayServer { +// type UpdatesStream = GatewayUpdatesStream; + +// /// Retrieve stats from gateway and save it to database +// async fn stats( +// &self, +// request: Request>, +// ) -> Result, Status> { +// let GatewayMetadata { +// network_id, +// hostname, +// .. +// } = Self::extract_metadata(request.metadata())?; +// let mut stream = request.into_inner(); +// let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL)); +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_stats", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); +// loop { +// // Wait for a message or update client map at least once a mninute, if no messages are +// // received. +// let stats_update = tokio::select! { +// message = stream.message() => { +// match message? { +// Some(update) => update, +// None => break, // Stream ended +// } +// } +// _ = disconnect_timer.tick() => { +// debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. \ +// Updating disconnected VPN clients"); +// // fetch location to get current peer disconnect threshold +// let location = self.fetch_location_from_db(network_id).await?; + +// // perform client state operations in a dedicated block to drop mutex guard +// let disconnected_clients = { +// // acquire lock on client state map +// let mut client_map = self.get_client_state_guard()?; + +// // disconnect inactive clients +// client_map.disconnect_inactive_vpn_clients_for_location(&location +// )? +// }; + +// // emit client disconnect events +// for (device, context) in disconnected_clients { +// self.emit_event(GrpcEvent::ClientDisconnected { +// context, +// location: location.clone(), +// device, +// })?; +// }; +// continue; +// } +// }; + +// debug!("Received stats message: {stats_update:?}"); +// let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else { +// debug!("Received stats message is empty, skipping."); +// continue; +// }; +// let public_key = peer_stats.public_key.clone(); + +// // fetch device from DB +// // TODO: fetch only when device has changed and use client state otherwise +// let device = match self.fetch_device_from_db(&public_key).await? { +// Some(device) => device, +// None => { +// warn!( +// "Received stats update for a device which does not exist: {public_key}, skipping." +// ); +// continue; +// } +// }; + +// // copy device ID for easier reference later +// let device_id = device.id; + +// // fetch user and location from DB for activity log +// // TODO: cache usernames since they don't change +// let user = self.fetch_user_from_db(device.user_id, &public_key).await?; +// let location = self.fetch_location_from_db(network_id).await?; + +// // convert stats to DB storage format +// let stats = WireguardPeerStats::from_peer_stats(peer_stats, network_id, device_id); + +// // only perform client state update if stats include an endpoint IP +// // otherwise a peer was added to the gateway interface +// // but has not connected yet +// if let Some(endpoint) = &stats.endpoint { +// // parse client endpoint IP +// let socket_addr: SocketAddr = endpoint.clone().parse().map_err(|err| { +// error!("Failed to parse VPN client endpoint: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to parse VPN client endpoint: {err}"), +// ) +// })?; + +// // perform client state operations in a dedicated block to drop mutex guard +// let disconnected_clients = { +// // acquire lock on client state map +// let mut client_map = self.get_client_state_guard()?; + +// // update connected clients map +// match client_map.get_vpn_client(network_id, &public_key) { +// Some(client_state) => { +// // update connected client state +// client_state.update_client_state( +// device, +// socket_addr, +// stats.latest_handshake, +// stats.upload, +// stats.download, +// ); +// } +// None => { +// // don't mark inactive peers as connected +// if (Utc::now().naive_utc() - stats.latest_handshake) +// < TimeDelta::seconds(location.peer_disconnect_threshold.into()) +// { +// // mark new VPN client as connected +// client_map.connect_vpn_client( +// network_id, +// &hostname, +// &public_key, +// &device, +// &user, +// socket_addr, +// &stats, +// )?; + +// // emit connection event +// let context = GrpcRequestContext::new( +// user.id, +// user.username.clone(), +// socket_addr.ip(), +// device.id, +// device.name.clone(), +// location.clone(), +// ); +// self.emit_event(GrpcEvent::ClientConnected { +// context, +// location: location.clone(), +// device: device.clone(), +// })?; +// } +// } +// } + +// // disconnect inactive clients +// client_map.disconnect_inactive_vpn_clients_for_location(&location)? +// }; + +// // emit client disconnect events +// for (device, context) in disconnected_clients { +// self.emit_event(GrpcEvent::ClientDisconnected { +// context, +// location: location.clone(), +// device, +// })?; +// } +// } + +// // Save stats to db +// let stats = match stats.save(&self.pool).await { +// Ok(stats) => stats, +// Err(err) => { +// error!("Saving WireGuard peer stats to db failed: {err}"); +// return Err(Status::new( +// Code::Internal, +// format!("Saving WireGuard peer stats to db failed: {err}"), +// )); +// } +// }; +// info!("Saved WireGuard peer stats to db."); +// debug!("WireGuard peer stats: {stats:?}"); +// } + +// Ok(Response::new(())) +// } + +// async fn config( +// &self, +// request: Request, +// ) -> Result, Status> { +// debug!("Sending configuration to gateway client."); +// let GatewayMetadata { +// network_id, +// hostname, +// version, +// .. +// // info, +// } = Self::extract_metadata(request.metadata())?; +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_config", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); + +// let mut conn = self.pool.acquire().await.map_err(|e| { +// error!("Failed to acquire DB connection: {e}"); +// Status::new( +// Code::Internal, +// "Failed to acquire DB connection".to_string(), +// ) +// })?; + +// let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) +// .await +// .map_err(|e| { +// error!("Network {network_id} not found"); +// Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) +// })? +// .ok_or_else(|| { +// Status::new( +// Code::Internal, +// format!("Network with id {network_id} not found"), +// ) +// })?; + +// debug!("Sending configuration to gateway client, network {network}."); + +// // store connected gateway in memory +// { +// let mut state = self.gateway_state.lock().unwrap(); +// state.add_gateway( +// network_id, +// &network.name, +// hostname, +// request.into_inner().name, +// self.mail_tx.clone(), +// version, +// ); +// } + +// network.connected_at = Some(Utc::now().naive_utc()); +// if let Err(err) = network.save(&mut *conn).await { +// error!("Failed to save updated network {network_id} in the database, status: {err}"); +// } + +// let peers = network.get_peers(&mut *conn).await.map_err(|error| { +// error!("Failed to fetch peers from the database for network {network_id}: {error}",); +// Status::new( +// Code::Internal, +// format!("Failed to retrieve peers from the database for network: {network_id}"), +// ) +// })?; +// let maybe_firewall_config = +// network +// .try_get_firewall_config(&mut conn) +// .await +// .map_err(|err| { +// error!("Failed to generate firewall config for network {network_id}: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to generate firewall config for network: {network_id}"), +// ) +// })?; + +// info!("Configuration sent to gateway client, network {network}."); + +// Ok(Response::new(gen_config( +// &network, +// peers, +// maybe_firewall_config, +// ))) +// } + +// async fn updates(&self, request: Request<()>) -> Result, Status> { +// let GatewayMetadata { +// network_id, +// hostname, +// .. +// // info, +// } = Self::extract_metadata(request.metadata())?; +// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. +// // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, +// // version = version.to_string(), info); +// // let _guard = span.enter(); + +// let Some(network) = WireguardNetwork::find_by_id(&self.pool, network_id) +// .await +// .map_err(|_| { +// error!("Failed to fetch network {network_id} from the database"); +// Status::new( +// Code::Internal, +// format!("Failed to retrieve network {network_id} from the database"), +// ) +// })? +// else { +// return Err(Status::new( +// Code::Internal, +// format!("Network with id {network_id} not found"), +// )); +// }; + +// info!("New client connected to updates stream: {hostname}, network {network}",); + +// let (tx, rx) = mpsc::channel(4); +// let events_rx = self.wireguard_tx.subscribe(); +// let mut state = self.gateway_state.lock().unwrap(); +// state +// .connect_gateway(network_id, &hostname, &self.pool) +// .map_err(|err| { +// error!("Failed to connect gateway on network {network_id}: {err}"); +// Status::new( +// Code::Internal, +// format!("Failed to connect gateway on network {network_id}"), +// ) +// })?; + +// // clone here before moving into a closure +// let gateway_hostname = hostname.clone(); +// let handle = tokio::spawn(async move { +// let mut update_handler = +// GatewayUpdatesHandler::new(network_id, network, gateway_hostname, events_rx, tx); +// update_handler.run().await; +// }); + +// Ok(Response::new(GatewayUpdatesStream::new( +// handle, +// rx, +// network_id, +// hostname, +// Arc::clone(&self.gateway_state), +// self.pool.clone(), +// ))) +// } +// } diff --git a/crates/defguard_core/src/grpc/gateway/tests.rs b/crates/defguard_core/src/grpc/gateway/tests.rs new file mode 100644 index 000000000..18174116e --- /dev/null +++ b/crates/defguard_core/src/grpc/gateway/tests.rs @@ -0,0 +1,87 @@ +use std::{ + io, + net::{IpAddr, Ipv4Addr}, +}; + +use ipnetwork::IpNetwork; +use tokio::{ + net::UnixListener, + sync::{broadcast, mpsc::unbounded_channel}, +}; +use tokio_stream::wrappers::UnixListenerStream; +use tonic::{transport::Server, Request, Response, Status, Streaming}; + +use super::*; + +pub(super) static TONIC_SOCKET: &str = "tonic.sock"; + +struct FakeGateway; + +#[tonic::async_trait] +impl gateway_server::Gateway for FakeGateway { + type BidiStream = UnboundedReceiverStream>; + + async fn bidi( + &self, + request: Request>, + ) -> Result, Status> { + let (_tx, rx) = mpsc::unbounded_channel(); + let mut stream = request.into_inner(); + tokio::spawn(async move { + loop { + match stream.message().await { + Ok(Some(_response)) => (), + Ok(None) => (), + Err(_err) => (), + } + } + }); + + Ok(Response::new(UnboundedReceiverStream::new(rx))) + } +} + +async fn fake_gateway() -> Result<(), io::Error> { + let gateway = FakeGateway {}; + + let uds = UnixListener::bind(TONIC_SOCKET)?; + let uds_stream = UnixListenerStream::new(uds); + + Server::builder() + .add_service(gateway_server::GatewayServer::new(gateway)) + .serve_with_incoming(uds_stream) + .await + .unwrap(); + + Ok(()) +} + +#[sqlx::test] +async fn test_gateway(pool: PgPool) { + let network = WireguardNetwork::new( + "TestNet".to_string(), + IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap(), + 50051, + "0.0.0.0".to_string(), + None, + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], + false, + 0, + 0, + ) + .save(&pool) + .await + .unwrap(); + let gateway = Gateway::new(network.id, "http://[::]:50051") + .save(&pool) + .await + .unwrap(); + let (events_tx, _events_rx) = broadcast::channel::(16); + let (mail_tx, _mail_rx) = unbounded_channel::(); + + let mut gateway_handler = GatewayHandler::new(gateway, None, pool, events_tx, mail_tx).unwrap(); + let handle = tokio::spawn(async move { + gateway_handler.handle_connection().await; + }); + handle.abort(); +} diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index a4c4ba3dc..f2e98d6ec 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -90,7 +90,6 @@ pub mod proto { use defguard_proto::{ auth::auth_service_server::AuthServiceServer, - gateway::gateway_service_server::GatewayServiceServer, proxy::{ AuthCallbackResponse, AuthInfoResponse, CoreError, CoreRequest, CoreResponse, core_request, core_response, proxy_client::ProxyClient, @@ -734,31 +733,31 @@ pub async fn build_grpc_service_router( .add_service(health_service) .add_service(auth_service); - let router = { - use crate::version::GatewayVersionInterceptor; - - let gateway_service = GatewayServiceServer::new(GatewayServer::new( - pool, - gateway_state, - client_state, - wireguard_tx, - mail_tx, - grpc_event_tx, - )); - - let own_version = Version::parse(VERSION)?; - router.add_service( - ServiceBuilder::new() - .layer(tonic::service::InterceptorLayer::new(JwtInterceptor::new( - ClaimsType::Gateway, - ))) - .layer(tonic::service::InterceptorLayer::new( - GatewayVersionInterceptor::new(MIN_GATEWAY_VERSION, incompatible_components), - )) - .layer(DefguardVersionLayer::new(own_version)) - .service(gateway_service), - ) - }; + // let router = { + // use crate::version::GatewayVersionInterceptor; + + // let gateway_service = GatewayServiceServer::new(GatewayServer::new( + // pool, + // gateway_state, + // client_state, + // wireguard_tx, + // mail_tx, + // grpc_event_tx, + // )); + + // let own_version = Version::parse(VERSION)?; + // router.add_service( + // ServiceBuilder::new() + // .layer(tonic::service::InterceptorLayer::new(JwtInterceptor::new( + // ClaimsType::Gateway, + // ))) + // .layer(tonic::service::InterceptorLayer::new( + // GatewayVersionInterceptor::new(MIN_GATEWAY_VERSION, incompatible_components), + // )) + // .layer(DefguardVersionLayer::new(own_version)) + // .service(gateway_service), + // ) + // }; let router = router.add_service(worker_service); diff --git a/migrations/20251125072923_network_gateways.down.sql b/migrations/20251125072923_network_gateways.down.sql new file mode 100644 index 000000000..5e727c02c --- /dev/null +++ b/migrations/20251125072923_network_gateways.down.sql @@ -0,0 +1,3 @@ +DROP TRIGGER gateway ON gateway; +DROP FUNCTION row_change(); +DROP TABLE gateway; diff --git a/migrations/20251125072923_network_gateways.up.sql b/migrations/20251125072923_network_gateways.up.sql new file mode 100644 index 000000000..3db149fd6 --- /dev/null +++ b/migrations/20251125072923_network_gateways.up.sql @@ -0,0 +1,20 @@ +CREATE TABLE gateway ( + id bigserial PRIMARY KEY, + network_id bigint NOT NULL, + url text NOT NULL, + hostname text NULL, + connected_at timestamp without time zone NULL, + disconnected_at timestamp without time zone NULL, + FOREIGN KEY(network_id) REFERENCES wireguard_network(id) +); +CREATE FUNCTION row_change() RETURNS trigger AS $$ +BEGIN + PERFORM pg_notify(TG_TABLE_NAME || '_change', + json_build_object('operation', TG_OP, 'old', row_to_json(OLD), 'new', row_to_json(NEW))::text + ); + RETURN NULL; +END; +$$ LANGUAGE plpgsql; +CREATE TRIGGER gateway + AFTER INSERT OR UPDATE OR DELETE ON gateway + FOR ROW EXECUTE FUNCTION row_change(); diff --git a/proto b/proto index 74d60d917..d8a8d1b27 160000 --- a/proto +++ b/proto @@ -1 +1 @@ -Subproject commit 74d60d9171048ba0ccaf8a21b05950fb7a673f09 +Subproject commit d8a8d1b27fe38f1bd71241971c90ed3852f06d5b From 2f1af6bc6c23450fdd7e82f64c3b038c7c9c4d97 Mon Sep 17 00:00:00 2001 From: jakub-tldr <78603704+jakub-tldr@users.noreply.github.com> Date: Thu, 27 Nov 2025 13:31:56 +0100 Subject: [PATCH 07/17] RPM config fix (#1730) --- .fpm | 1 + 1 file changed, 1 insertion(+) diff --git a/.fpm b/.fpm index 062ba199b..a03eba34a 100644 --- a/.fpm +++ b/.fpm @@ -3,3 +3,4 @@ --description "Defguard Core service" --url "https://defguard.net/" --maintainer "Defguard" +--config-files /etc/defguard/core.conf From dcdc0f2606614dd231f809560e3a641bac0761b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 28 Nov 2025 11:25:17 +0100 Subject: [PATCH 08/17] Let it build --- Cargo.lock | 40 +++++++++---------- .../defguard_core/src/grpc/gateway/handler.rs | 4 +- crates/defguard_core/src/grpc/gateway/mod.rs | 5 ++- .../defguard_core/src/grpc/gateway/tests.rs | 4 +- crates/defguard_core/src/grpc/mod.rs | 24 +++++------ .../integration/grpc/common/mock_gateway.rs | 4 +- .../tests/integration/grpc/common/mod.rs | 12 +++--- .../tests/integration/grpc/gateway.rs | 2 +- .../defguard_core/tests/integration/main.rs | 2 +- 9 files changed, 48 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 78a69433d..5703ef2bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2569,9 +2569,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" dependencies = [ "once_cell", "wasm-bindgen", @@ -4715,9 +4715,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.16.0" +version = "3.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10574371d41b0d9b2cff89418eda27da52bcaff2cc8741db26382a77c29131f1" +checksum = "4fa237f2807440d238e0364a218270b98f767a00d3dada77b1c53ae88940e2e7" dependencies = [ "base64 0.22.1", "chrono", @@ -4734,9 +4734,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.16.0" +version = "3.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08a72d8216842fdd57820dc78d840bef99248e35fb2554ff923319e60f2d686b" +checksum = "52a8e3ca0ca629121f70ab50f95249e5a6f925cc0f6ffe8256c45b728875706c" dependencies = [ "darling 0.21.3", "proc-macro2", @@ -5795,9 +5795,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "2d15d90a0b5c19378952d479dc858407149d7bb45a14de0142f6c534b16fc647" dependencies = [ "log", "pin-project-lite", @@ -6171,9 +6171,9 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" +checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" dependencies = [ "cfg-if", "once_cell", @@ -6184,9 +6184,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.55" +version = "0.4.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "551f88106c6d5e7ccc7cd9a16f312dd3b5d36ea8b4954304657d5dfba115d4a0" +checksum = "836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" dependencies = [ "cfg-if", "js-sys", @@ -6197,9 +6197,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6207,9 +6207,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" dependencies = [ "bumpalo", "proc-macro2", @@ -6220,9 +6220,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" dependencies = [ "unicode-ident", ] @@ -6242,9 +6242,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" dependencies = [ "js-sys", "wasm-bindgen", diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 66b9d07aa..497ac064a 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -27,7 +27,7 @@ use crate::{ handlers::mail::send_gateway_disconnected_email, }; -/// One instance per connected gateway. +/// One instance per connected Gateway. pub(super) struct GatewayHandler { endpoint: Endpoint, gateway: Gateway, @@ -202,7 +202,7 @@ impl GatewayHandler { let channel = self.endpoint.connect_with_connector_lazy(tower::service_fn( |_: tonic::transport::Uri| async { Ok::<_, std::io::Error>(hyper_util::rt::TokioIo::new( - tokio::net::UnixStream::connect(super::tests::TONIC_SOCKET).await?, + tokio::net::UnixStream::connect(super::TONIC_SOCKET).await?, )) }, )); diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index b94db32f1..8f1f5ec66 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -44,10 +44,11 @@ pub mod client_state; pub(crate) mod handler; pub mod map; pub(crate) mod state; -#[cfg(test)] -mod tests; +// #[cfg(test)] +// mod tests; const PEER_DISCONNECT_INTERVAL: u64 = 60; +static TONIC_SOCKET: &str = "tonic.sock"; /// Sends given `GatewayEvent` to be handled by gateway GRPC server /// diff --git a/crates/defguard_core/src/grpc/gateway/tests.rs b/crates/defguard_core/src/grpc/gateway/tests.rs index 18174116e..f79b77dba 100644 --- a/crates/defguard_core/src/grpc/gateway/tests.rs +++ b/crates/defguard_core/src/grpc/gateway/tests.rs @@ -9,12 +9,10 @@ use tokio::{ sync::{broadcast, mpsc::unbounded_channel}, }; use tokio_stream::wrappers::UnixListenerStream; -use tonic::{transport::Server, Request, Response, Status, Streaming}; +use tonic::{Request, Response, Status, Streaming, transport::Server}; use super::*; -pub(super) static TONIC_SOCKET: &str = "tonic.sock"; - struct FakeGateway; #[tonic::async_trait] diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index f2e98d6ec..eedad0a01 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -680,13 +680,13 @@ pub async fn run_grpc_server( server, pool, worker_state, - gateway_state, - client_state, - wireguard_tx, - mail_tx, + // gateway_state, + // client_state, + // wireguard_tx, + // mail_tx, failed_logins, - grpc_event_tx, - incompatible_components, + // grpc_event_tx, + // incompatible_components, ) .await?; @@ -707,13 +707,13 @@ pub async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, - gateway_state: Arc>, - client_state: Arc>, - wireguard_tx: Sender, - mail_tx: UnboundedSender, + // gateway_state: Arc>, + // client_state: Arc>, + // wireguard_tx: Sender, + // mail_tx: UnboundedSender, failed_logins: Arc>, - grpc_event_tx: UnboundedSender, - incompatible_components: Arc>, + // grpc_event_tx: UnboundedSender, + // incompatible_components: Arc>, ) -> Result { let auth_service = AuthServiceServer::new(AuthServer::new(pool.clone(), failed_logins)); diff --git a/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs b/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs index 6440f5e8f..11bcdafbf 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mock_gateway.rs @@ -2,8 +2,8 @@ use std::time::Duration; use defguard_core::grpc::{AUTHORIZATION_HEADER, HOSTNAME_HEADER}; use defguard_proto::gateway::{ - Configuration, ConfigurationRequest, StatsUpdate, Update, - gateway_service_client::GatewayServiceClient, + Configuration, ConfigurationRequest, Update, + }; use defguard_version::{Version, client::ClientVersionInterceptor}; use tokio::{ diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 96609dbfa..b919afcd4 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -28,7 +28,7 @@ use tower::service_fn; use crate::common::{init_config, initialize_users}; -pub mod mock_gateway; +// pub mod mock_gateway; pub struct TestGrpcServer { grpc_server_task_handle: JoinHandle<()>, @@ -156,13 +156,13 @@ pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { server, pool.clone(), worker_state, - gateway_state.clone(), - client_state.clone(), - wg_tx.clone(), + // gateway_state.clone(), + // client_state.clone(), + // wg_tx.clone(), mail_tx, failed_logins, - grpc_event_tx, - Default::default(), + // grpc_event_tx, + // Default::default(), ) .await .unwrap(); diff --git a/crates/defguard_core/tests/integration/grpc/gateway.rs b/crates/defguard_core/tests/integration/grpc/gateway.rs index d27fca1e7..75c0c8339 100644 --- a/crates/defguard_core/tests/integration/grpc/gateway.rs +++ b/crates/defguard_core/tests/integration/grpc/gateway.rs @@ -21,7 +21,7 @@ use defguard_core::{ }; use defguard_proto::{ enterprise::firewall::FirewallPolicy, - gateway::{Configuration, PeerStats, StatsUpdate, Update, stats_update::Payload, update}, + gateway::{Configuration, PeerStats, Update, stats_update::Payload, update}, }; use semver::Version; use sqlx::{ diff --git a/crates/defguard_core/tests/integration/main.rs b/crates/defguard_core/tests/integration/main.rs index f85d8d0fa..b3793ede2 100644 --- a/crates/defguard_core/tests/integration/main.rs +++ b/crates/defguard_core/tests/integration/main.rs @@ -1,3 +1,3 @@ mod api; mod common; -mod grpc; +// mod grpc; From 9af11aa77cb89b695d9fc9bb20bef2acdeab5611 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Fri, 28 Nov 2025 13:31:44 +0100 Subject: [PATCH 09/17] Database trigger --- crates/defguard/src/main.rs | 7 +- crates/defguard_common/src/config.rs | 26 +++- crates/defguard_common/src/db/mod.rs | 15 +++ .../defguard_core/src/db/models/wireguard.rs | 4 +- .../defguard_core/src/grpc/gateway/handler.rs | 11 +- crates/defguard_core/src/grpc/gateway/mod.rs | 2 +- .../defguard_core/src/grpc/gateway/state.rs | 16 ++- crates/defguard_core/src/grpc/mod.rs | 115 ++++++++++++++++-- 8 files changed, 177 insertions(+), 19 deletions(-) diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 3c7576a2e..5188cfd41 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -24,7 +24,7 @@ use defguard_core::{ grpc::{ WorkerState, gateway::{client_state::ClientMap, map::GatewayMap}, - run_grpc_bidi_stream, run_grpc_server, + run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, }, init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, @@ -153,6 +153,11 @@ async fn main() -> Result<(), anyhow::Error> { // run services tokio::select! { + res = run_grpc_gateway_stream( + pool.clone(), + wireguard_tx.clone(), + mail_tx.clone() + ) => error!("Gateway gRPC stream returned early: {res:?}"), res = run_grpc_bidi_stream( pool.clone(), wireguard_tx.clone(), diff --git a/crates/defguard_common/src/config.rs b/crates/defguard_common/src/config.rs index 2549ce610..97f8d59b2 100644 --- a/crates/defguard_common/src/config.rs +++ b/crates/defguard_common/src/config.rs @@ -1,4 +1,4 @@ -use std::{net::IpAddr, sync::OnceLock}; +use std::{fs::read_to_string, io, net::IpAddr, sync::OnceLock}; use clap::{Args, Parser, Subcommand}; use humantime::Duration; @@ -13,6 +13,7 @@ use rsa::{ }; use secrecy::{ExposeSecret, SecretString}; use serde::Serialize; +use tonic::transport::{Certificate, ClientTlsConfig, Identity}; pub static SERVER_CONFIG: OnceLock = OnceLock::new(); @@ -65,9 +66,11 @@ pub struct DefGuardConfig { #[arg(long, env = "DEFGUARD_GRPC_PORT", default_value_t = 50055)] pub grpc_port: u16, + // Certificate authority (CA), certificate, and key for gRPC communication over HTTPS. + #[arg(long, env = "DEFGUARD_GRPC_CA")] + pub grpc_ca: Option, #[arg(long, env = "DEFGUARD_GRPC_CERT")] pub grpc_cert: Option, - #[arg(long, env = "DEFGUARD_GRPC_KEY")] pub grpc_key: Option, @@ -298,6 +301,25 @@ impl DefGuardConfig { } url } + + /// Provide [`ClientTlsConfig`] from paths to cerfiticate, key, and cerfiticate authority (CA). + pub fn grpc_client_tls_config(&self) -> Result, io::Error> { + if self.grpc_ca.is_none() && (self.grpc_cert.is_none() || self.grpc_key.is_none()) { + return Ok(None); + } + let mut tls = ClientTlsConfig::new(); + if let (Some(cert_path), Some(key_path)) = (&self.grpc_cert, &self.grpc_key) { + let cert = read_to_string(cert_path)?; + let key = read_to_string(key_path)?; + tls = tls.identity(Identity::from_pem(cert, key)); + } + if let Some(ca_path) = &self.grpc_ca { + let ca = read_to_string(ca_path)?; + tls = tls.ca_certificate(Certificate::from_pem(ca)); + } + + Ok(Some(tls)) + } } impl Default for DefGuardConfig { diff --git a/crates/defguard_common/src/db/mod.rs b/crates/defguard_common/src/db/mod.rs index d7ca63d05..cc49e8289 100644 --- a/crates/defguard_common/src/db/mod.rs +++ b/crates/defguard_common/src/db/mod.rs @@ -45,3 +45,18 @@ pub async fn setup_pool(options: PgConnectOptions) -> PgPool { .expect("Cannot run database migrations."); pool } + +#[derive(Deserialize)] +#[serde(rename_all = "UPPERCASE")] +pub enum TriggerOperation { + Insert, + Update, + Delete, +} + +#[derive(Deserialize)] +pub struct ChangeNotification { + pub operation: TriggerOperation, + pub old: Option, + pub new: Option, +} diff --git a/crates/defguard_core/src/db/models/wireguard.rs b/crates/defguard_core/src/db/models/wireguard.rs index 2e559dcbf..32c4a4e4f 100644 --- a/crates/defguard_core/src/db/models/wireguard.rs +++ b/crates/defguard_core/src/db/models/wireguard.rs @@ -23,8 +23,8 @@ use ipnetwork::{IpNetwork, IpNetworkError, NetworkSize}; use model_derive::Model; use rand::rngs::OsRng; use sqlx::{ - FromRow, PgConnection, PgExecutor, PgPool, Type, - postgres::types::PgInterval, query, query_as, query_scalar, + FromRow, PgConnection, PgExecutor, PgPool, Type, postgres::types::PgInterval, query, query_as, + query_scalar, }; use thiserror::Error; use tokio::sync::broadcast::Sender; diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 497ac064a..53403fcc2 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -8,7 +8,10 @@ use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; use sqlx::PgPool; use tokio::{ - sync::mpsc::{self, Sender, UnboundedSender}, + sync::{ + broadcast::Sender, + mpsc::{self, UnboundedSender}, + }, time::sleep, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -28,7 +31,7 @@ use crate::{ }; /// One instance per connected Gateway. -pub(super) struct GatewayHandler { +pub(crate) struct GatewayHandler { endpoint: Endpoint, gateway: Gateway, message_id: AtomicU64, @@ -38,7 +41,7 @@ pub(super) struct GatewayHandler { } impl GatewayHandler { - pub(super) fn new( + pub(crate) fn new( gateway: Gateway, tls_config: Option, pool: PgPool, @@ -193,7 +196,7 @@ impl GatewayHandler { } /// Connect to Gateway and handle its messages through gRPC. - pub(super) async fn handle_connection(&mut self) -> ! { + pub(crate) async fn handle_connection(&mut self) -> ! { let uri = self.endpoint.uri(); loop { #[cfg(not(test))] diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 8f1f5ec66..ab6d57614 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -48,7 +48,7 @@ pub(crate) mod state; // mod tests; const PEER_DISCONNECT_INTERVAL: u64 = 60; -static TONIC_SOCKET: &str = "tonic.sock"; +pub(super) static TONIC_SOCKET: &str = "tonic.sock"; /// Sends given `GatewayEvent` to be handled by gateway GRPC server /// diff --git a/crates/defguard_core/src/grpc/gateway/state.rs b/crates/defguard_core/src/grpc/gateway/state.rs index 7219c30d1..788801106 100644 --- a/crates/defguard_core/src/grpc/gateway/state.rs +++ b/crates/defguard_core/src/grpc/gateway/state.rs @@ -13,6 +13,7 @@ use utoipa::ToSchema; use uuid::Uuid; use crate::{ + db::models::gateway::Gateway, grpc::MIN_GATEWAY_VERSION, handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, }; @@ -23,7 +24,7 @@ pub struct GatewayState { pub connected: bool, pub network_id: Id, pub network_name: String, - pub name: Option, + pub name: Option, // TODO: remove pub hostname: String, pub connected_at: Option, pub disconnected_at: Option, @@ -36,6 +37,19 @@ pub struct GatewayState { } impl GatewayState { + // pub(crate) fn from_gateway(gateway: &Gateway, network_name: &str) -> Self { + // Self { + // id: gateway.id, + // connected: gateway.is_connected(), + // network_id: gateway.network_id, + // network_name: network_name.to_owned(), + // name: None, // TODO: remove + // hostname: gateway.hostname.clone().unwrap_or_default(), + // connected_at: gateway.connected_at, + // disconnected_at: gateway.disconnected_at, + // } + // } + #[must_use] pub fn new>( network_id: Id, diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index eedad0a01..cee430a61 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -10,7 +10,7 @@ use axum::http::Uri; use defguard_common::{ VERSION, auth::claims::ClaimsType, - db::{Id, models::Settings}, + db::{ChangeNotification, Id, TriggerOperation, models::Settings}, }; use defguard_mail::Mail; use defguard_version::{ @@ -20,12 +20,13 @@ use defguard_version::{ use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; use reqwest::Url; use serde::Serialize; -use sqlx::PgPool; +use sqlx::{PgPool, postgres::PgListener}; use tokio::{ sync::{ broadcast::Sender, mpsc::{self, UnboundedSender}, }, + task::{AbortHandle, JoinSet}, time::sleep, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -35,19 +36,20 @@ use tonic::{ Certificate, ClientTlsConfig, Endpoint, Identity, Server, ServerTlsConfig, server::Router, }, }; -use tower::ServiceBuilder; use self::{ auth::AuthServer, client_mfa::ClientMfaServer, enrollment::EnrollmentServer, - gateway::GatewayServer, interceptor::JwtInterceptor, password_reset::PasswordResetServer, - worker::WorkerServer, + gateway::handler::GatewayHandler, interceptor::JwtInterceptor, + password_reset::PasswordResetServer, worker::WorkerServer, }; -pub use crate::version::MIN_GATEWAY_VERSION; use crate::{ auth::failed_login::FailedLoginMap, db::{ AppEvent, GatewayEvent, - models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, + models::{ + enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, + gateway::Gateway, + }, }, enterprise::{ db::models::{ @@ -65,7 +67,10 @@ use crate::{ events::{BidiStreamEvent, GrpcEvent}, grpc::gateway::{client_state::ClientMap, map::GatewayMap}, server_config, - version::{IncompatibleComponents, IncompatibleProxyData, is_proxy_version_supported}, + version::{ + IncompatibleComponents, IncompatibleProxyData, MIN_GATEWAY_VERSION, + is_proxy_version_supported, + }, }; static VERSION_ZERO: Version = Version::new(0, 0, 0); @@ -546,6 +551,100 @@ async fn handle_proxy_message_loop( Ok(()) } +const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; + +/// Bi-directional gRPC stream for comminication with Defguard Gateway. +pub async fn run_grpc_gateway_stream( + pool: PgPool, + events_tx: Sender, + mail_tx: UnboundedSender, +) -> Result<(), anyhow::Error> { + let config = server_config(); + let tls_config = config.grpc_client_tls_config()?; + + let mut abort_handles = HashMap::new(); + + let mut tasks = JoinSet::new(); + // Helper closure to launch `GatewayHandler`. + let mut launch_gateway_handler = + |gateway: Gateway| -> Result { + let mut gateway_handler = GatewayHandler::new( + gateway, + tls_config.clone(), + pool.clone(), + events_tx.clone(), + mail_tx.clone(), + )?; + let abort_handle = tasks.spawn(async move { + gateway_handler.handle_connection().await; + }); + Ok(abort_handle) + }; + + let gateways = Gateway::all(&pool).await?; + for gateway in gateways { + let id = gateway.id; + let abort_handle = launch_gateway_handler(gateway)?; + abort_handles.insert(id, abort_handle); + } + + // Observe gateway URL changes. + let mut listener = PgListener::connect_with(&pool).await?; + listener.listen(GATEWAY_TABLE_TRIGGER).await?; + while let Ok(notification) = listener.recv().await { + let payload = notification.payload(); + match serde_json::from_str::>>(payload) { + Ok(gateway_notification) => match gateway_notification.operation { + TriggerOperation::Insert => { + if let Some(new) = gateway_notification.new { + let id = new.id; + let abort_handle = launch_gateway_handler(new)?; + abort_handles.insert(id, abort_handle); + } + } + TriggerOperation::Update => { + if let (Some(old), Some(new)) = + (gateway_notification.old, gateway_notification.new) + { + if old.url == new.url { + debug!( + "Gateway URL didn't change. Keeping the current gateway handler" + ); + } else if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!("Aborting connection to {old}, it has changed in the database"); + abort_handle.abort(); + let id = new.id; + let abort_handle = launch_gateway_handler(new)?; + abort_handles.insert(id, abort_handle); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + TriggerOperation::Delete => { + if let Some(old) = gateway_notification.old { + if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to {old}, it has disappeard from the database" + ); + abort_handle.abort(); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + }, + Err(err) => error!("Failed to de-serialize database notification object: {err}"), + } + } + + while let Some(Ok(_result)) = tasks.join_next().await { + debug!("Gateway gRPC task has ended"); + } + + Ok(()) +} + /// Bi-directional gRPC stream for communication with Defguard Proxy. #[instrument(skip_all)] pub async fn run_grpc_bidi_stream( From 20f1c6b05d730b0d78cc69ba752741cd99c4036d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Mon, 1 Dec 2025 11:17:37 +0100 Subject: [PATCH 10/17] Handle gateway stats --- Cargo.lock | 70 ++- crates/defguard/src/main.rs | 10 +- .../src/grpc/gateway/client_state.rs | 3 +- .../defguard_core/src/grpc/gateway/handler.rs | 277 ++++++++-- crates/defguard_core/src/grpc/gateway/mod.rs | 487 +++++++----------- .../defguard_core/src/grpc/gateway/state.rs | 2 +- crates/defguard_core/src/grpc/mod.rs | 36 +- 7 files changed, 462 insertions(+), 423 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5703ef2bc..4e4af3ad9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -574,9 +574,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.47" +version = "1.2.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" +checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a" dependencies = [ "find-msvc-tools", "jobserver", @@ -616,7 +616,7 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -2132,13 +2132,13 @@ dependencies = [ [[package]] name = "hostname" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a56f203cd1c76362b69e3863fd987520ac36cf70a8c92627449b2f64a8cf7d65" +checksum = "617aaa3557aef3810a6369d0a99fac8a080891b68bd9f9812a1eeda0c0730cbd" dependencies = [ "cfg-if", "libc", - "windows-link 0.1.3", + "windows-link", ] [[package]] @@ -3520,7 +3520,7 @@ dependencies = [ "libc", "redox_syscall", "smallvec", - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -3636,11 +3636,12 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.7.1" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" dependencies = [ "fixedbitset", + "hashbrown 0.15.5", "indexmap 2.12.1", ] @@ -3904,9 +3905,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" +checksum = "101fec8d036f8d9d4a1e8ebf90d566d1d798f3b1aa379d2576a54a0d9acea5bd" dependencies = [ "bytes", "prost-derive", @@ -3914,15 +3915,14 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" +checksum = "528a07106a21e01f4880c09818d0b7e73d0f0993536ddfff161754b5c20a086c" dependencies = [ "heck", "itertools 0.14.0", "log", "multimap", - "once_cell", "petgraph", "prettyplease", "prost", @@ -3936,9 +3936,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" +checksum = "d2d93e596a829ebe00afa41c3a056e6308d6b8a4c7d869edf184e2c91b1ba564" dependencies = [ "anyhow", "itertools 0.14.0", @@ -3949,9 +3949,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.14.1" +version = "0.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" +checksum = "f5d7b7346e150de32340ae3390b8b3ffa37ad93ec31fb5dad86afe817619e4e7" dependencies = [ "prost", ] @@ -4424,9 +4424,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.13.0" +version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94182ad936a0c91c324cd46c6511b9510ed16af436d7b5bab34beab0afd55f7a" +checksum = "708c0f9d5f54ba0272468c1d306a52c495b31fa155e91bc25371e6df7996908c" dependencies = [ "web-time", "zeroize", @@ -5839,9 +5839,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", @@ -6409,7 +6409,7 @@ checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" dependencies = [ "windows-implement", "windows-interface", - "windows-link 0.2.1", + "windows-link", "windows-result", "windows-strings", ] @@ -6436,12 +6436,6 @@ dependencies = [ "syn", ] -[[package]] -name = "windows-link" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" - [[package]] name = "windows-link" version = "0.2.1" @@ -6454,7 +6448,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" dependencies = [ - "windows-link 0.2.1", + "windows-link", "windows-result", "windows-strings", ] @@ -6465,7 +6459,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -6474,7 +6468,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -6519,7 +6513,7 @@ version = "0.61.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" dependencies = [ - "windows-link 0.2.1", + "windows-link", ] [[package]] @@ -6559,7 +6553,7 @@ version = "0.53.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ - "windows-link 0.2.1", + "windows-link", "windows_aarch64_gnullvm 0.53.1", "windows_aarch64_msvc 0.53.1", "windows_i686_gnu 0.53.1", @@ -6801,18 +6795,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea879c944afe8a2b25fef16bb4ba234f47c694565e97383b36f3a878219065c" +checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf955aa904d6040f70dc8e9384444cb1030aed272ba3cb09bbc4ab9e7c1f34f5" +checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" dependencies = [ "proc-macro2", "quote", diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 5188cfd41..11f25325a 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -155,8 +155,10 @@ async fn main() -> Result<(), anyhow::Error> { tokio::select! { res = run_grpc_gateway_stream( pool.clone(), + client_state, wireguard_tx.clone(), - mail_tx.clone() + mail_tx.clone(), + grpc_event_tx, ) => error!("Gateway gRPC stream returned early: {res:?}"), res = run_grpc_bidi_stream( pool.clone(), @@ -168,15 +170,9 @@ async fn main() -> Result<(), anyhow::Error> { res = run_grpc_server( Arc::clone(&worker_state), pool.clone(), - Arc::clone(&gateway_state), - client_state, - wireguard_tx.clone(), - mail_tx.clone(), grpc_cert, grpc_key, failed_logins.clone(), - grpc_event_tx, - Arc::clone(&incompatible_components), ) => error!("gRPC server returned early: {res:?}"), res = run_web_server( worker_state, diff --git a/crates/defguard_core/src/grpc/gateway/client_state.rs b/crates/defguard_core/src/grpc/gateway/client_state.rs index 1bc49a404..8f0f5ecd4 100644 --- a/crates/defguard_core/src/grpc/gateway/client_state.rs +++ b/crates/defguard_core/src/grpc/gateway/client_state.rs @@ -117,7 +117,8 @@ impl ClientMap { stats: &WireguardPeerStats, ) -> Result<(), ClientMapError> { info!( - "VPN client {} with public key {public_key} connected to location {location_id} through gateway {gateway_hostname}", + "VPN client {} with public key {public_key} connected to location {location_id} \ + through Gateway {gateway_hostname}", device.name ); diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 53403fcc2..30c6ecc6f 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -1,8 +1,13 @@ use std::{ + net::SocketAddr, str::FromStr, - sync::atomic::{AtomicU64, Ordering}, + sync::{ + Arc, Mutex, + atomic::{AtomicU64, Ordering}, + }, }; +use chrono::{TimeDelta, Utc}; use defguard_common::{auth::claims::Claims, db::Id}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; @@ -23,10 +28,10 @@ use tonic::{ use crate::{ ClaimsType, db::{ - Device, GatewayEvent, WireguardNetwork, + Device, GatewayEvent, User, WireguardNetwork, models::{gateway::Gateway, wireguard_peer_stats::WireguardPeerStats}, }, - grpc::TEN_SECS, + grpc::{ClientMap, GrpcEvent, TEN_SECS, gateway::GrpcRequestContext}, handlers::mail::send_gateway_disconnected_email, }; @@ -36,8 +41,10 @@ pub(crate) struct GatewayHandler { gateway: Gateway, message_id: AtomicU64, pool: PgPool, + client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, + grpc_event_tx: UnboundedSender, } impl GatewayHandler { @@ -45,8 +52,10 @@ impl GatewayHandler { gateway: Gateway, tls_config: Option, pool: PgPool, + client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, + grpc_event_tx: UnboundedSender, ) -> Result { let endpoint = Endpoint::from_shared(gateway.url.to_string())? .http2_keep_alive_interval(TEN_SECS) @@ -63,8 +72,10 @@ impl GatewayHandler { gateway, message_id: AtomicU64::new(0), pool, + client_state, events_tx, mail_tx, + grpc_event_tx, }) } @@ -195,6 +206,79 @@ impl GatewayHandler { }; } + /// Helper method to fetch `Device` info from DB by pubkey and return appropriate errors + async fn fetch_device_from_db(&self, public_key: &str) -> Result>, Status> { + let device = Device::find_by_pubkey(&self.pool, public_key) + .await + .map_err(|err| { + error!("Failed to retrieve device with public key {public_key}: {err}",); + Status::new( + Code::Internal, + format!("Failed to retrieve device with public key {public_key}: {err}",), + ) + })?; + + Ok(device) + } + + /// Helper method to fetch `WireguardNetwork` info from DB and return appropriate errors + async fn fetch_location_from_db( + &self, + location_id: Id, + ) -> Result, Status> { + let location = match WireguardNetwork::find_by_id(&self.pool, location_id).await { + Ok(Some(location)) => location, + Ok(None) => { + error!("Location {location_id} not found"); + return Err(Status::new( + Code::Internal, + format!("Location {location_id} not found"), + )); + } + Err(err) => { + error!("Failed to retrieve location {location_id}: {err}",); + return Err(Status::new( + Code::Internal, + format!("Failed to retrieve location {location_id}: {err}",), + )); + } + }; + Ok(location) + } + + /// Helper method to fetch `User` info from DB and return appropriate errors + async fn fetch_user_from_db(&self, user_id: Id, public_key: &str) -> Result, Status> { + let user = match User::find_by_id(&self.pool, user_id).await { + Ok(Some(user)) => user, + Ok(None) => { + error!("User {user_id} assigned to device with public key {public_key} not found"); + return Err(Status::new( + Code::Internal, + format!("User assigned to device with public key {public_key} not found"), + )); + } + Err(err) => { + error!( + "Failed to retrieve user {user_id} for device with public key {public_key}: {err}", + ); + return Err(Status::new( + Code::Internal, + format!( + "Failed to retrieve user for device with public key {public_key}: {err}", + ), + )); + } + }; + + Ok(user) + } + + fn emit_event(&self, event: GrpcEvent) { + if self.grpc_event_tx.send(event).is_err() { + warn!("Failed to send gRPC event"); + } + } + /// Connect to Gateway and handle its messages through gRPC. pub(crate) async fn handle_connection(&mut self) -> ! { let uri = self.endpoint.uri(); @@ -229,11 +313,11 @@ impl GatewayHandler { 'message: loop { match resp_stream.message().await { Ok(None) => { - info!("stream was closed by the sender"); + info!("Stream was closed by the sender."); break 'message; } Ok(Some(received)) => { - info!("Received message from gateway."); + info!("Received message from Gateway."); debug!("Message from Gateway {uri}"); match received.payload { Some(core_request::Payload::ConfigRequest(config_request)) => { @@ -307,6 +391,7 @@ impl GatewayHandler { }; // tokio::spawn(super::handle_events( // network, + // self.gateway.hostname.unwrap_or_default().clone(), // tx.clone(), // self.events_tx.subscribe(), // )); @@ -314,43 +399,163 @@ impl GatewayHandler { Some(core_request::Payload::PeerStats(peer_stats)) => { if !config_sent { warn!( - "Ignoring peer statistics from {} because it didn't \ + "Ignoring peer statistics from {} because it hasn't \ authorize itself", self.gateway ); continue; } - // let public_key = peer_stats.public_key.clone(); - // let mut stats = WireguardPeerStats::from_peer_stats( - // peer_stats, - // self.gateway.network_id, - - // ); - // // Get device by public key and fill in stats.device_id - // match Device::find_by_pubkey(&self.pool, &public_key).await { - // Ok(Some(device)) => { - // stats.device_id = device.id; - // match stats.save(&self.pool).await { - // Ok(_) => { - // info!("Saved WireGuard peer stats to database.") - // } - // Err(err) => error!( - // "Failed to save WireGuard peer stats to database: \ - // {err}" - // ), - // } - // } - // Ok(None) => { - // error!("Device with public key {public_key} not found"); - // } - // Err(err) => { - // error!( - // "Failed to retrieve device with public key \ - // {public_key}: {err}", - // ); - // } - // }; + let public_key = peer_stats.public_key.clone(); + + // fetch device from DB + // TODO: fetch only when device has changed and use client state + // otherwise + let Ok(Some(device)) = self.fetch_device_from_db(&public_key).await + else { + warn!( + "Received stats update for a device which does not \ + exist: {public_key}, skipping." + ); + continue; + }; + + // copy device ID for easier reference later + let device_id = device.id; + + // fetch user and location from DB for activity log + // TODO: cache usernames since they don't change + let Ok(user) = + self.fetch_user_from_db(device.user_id, &public_key).await + else { + continue; + }; + let Ok(location) = + self.fetch_location_from_db(self.gateway.network_id).await + else { + continue; + }; + + // Convert stats to database storage format. + let stats = WireguardPeerStats::from_peer_stats( + peer_stats, + self.gateway.network_id, + device_id, + ); + + // Only perform client state update if stats include an endpoint IP. + // Otherwise, a peer was added to the gateway interface, but hasn't + // connected yet. + if let Some(endpoint) = &stats.endpoint { + // parse client endpoint IP + let Ok(socket_addr) = endpoint.clone().parse::() + else { + error!("Failed to parse VPN client endpoint"); + continue; + }; + + // Perform client state operations in a dedicated block to drop + // mutex guard. + let disconnected_clients = { + // acquire lock on client state map + let mut client_map = self.client_state.lock().unwrap(); + + // update connected clients map + match client_map + .get_vpn_client(self.gateway.network_id, &public_key) + { + Some(client_state) => { + // update connected client state + client_state.update_client_state( + device, + socket_addr, + stats.latest_handshake, + stats.upload, + stats.download, + ); + } + None => { + // don't mark inactive peers as connected + if (Utc::now().naive_utc() - stats.latest_handshake) + < TimeDelta::seconds( + location.peer_disconnect_threshold.into(), + ) + { + // mark new VPN client as connected + if client_map + .connect_vpn_client( + self.gateway.network_id, + // Hostname is for logging only. + &self + .gateway + .hostname + .as_ref() + .cloned() + .unwrap_or_default(), + &public_key, + &device, + &user, + socket_addr, + &stats, + ) + .is_err() + { + // TODO: log message + continue; + } + + // emit connection event + let context = GrpcRequestContext::new( + user.id, + user.username.clone(), + socket_addr.ip(), + device.id, + device.name.clone(), + location.clone(), + ); + self.emit_event(GrpcEvent::ClientConnected { + context, + location: location.clone(), + device: device.clone(), + }); + } + } + } + + // disconnect inactive clients + let Ok(clients) = client_map + .disconnect_inactive_vpn_clients_for_location( + &location, + ) + else { + // TODO: log message + continue; + }; + clients + }; + + // emit client disconnect events + for (device, context) in disconnected_clients { + self.emit_event(GrpcEvent::ClientDisconnected { + context, + location: location.clone(), + device, + }); + } + } + + // Save stats to database. + let stats = match stats.save(&self.pool).await { + Ok(stats) => stats, + Err(err) => { + error!( + "Saving WireGuard peer stats to database failed: {err}" + ); + continue; + } + }; + info!("Saved WireGuard peer stats to database."); + debug!("WireGuard peer stats: {stats:?}"); } None => (), }; diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index ab6d57614..329a67f85 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,40 +1,30 @@ use std::{ - net::{IpAddr, SocketAddr}, - pin::Pin, - sync::{Arc, Mutex, MutexGuard}, - task::{Context, Poll}, + net::IpAddr, + sync::{Arc, Mutex}, }; -use chrono::{DateTime, TimeDelta, Utc}; +use chrono::{DateTime, Utc}; use client_state::ClientMap; use defguard_common::db::{Id, NoId}; use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, - gateway::{ - Configuration, ConfigurationRequest, CoreResponse, Peer, PeerStats, Update, UpdateType, - core_response, update, - }, + gateway::{Configuration, Peer, PeerStats, Update, update}, }; use defguard_version::version_info_from_metadata; use semver::Version; use sqlx::PgPool; use thiserror::Error; -use tokio::{ - sync::{ - broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::{self, Receiver, UnboundedSender, error::SendError}, - }, - task::JoinHandle, - time::{Duration, interval}, +use tokio::sync::{ + broadcast::{Receiver as BroadcastReceiver, Sender}, + mpsc::{self, UnboundedSender, error::SendError}, }; -use tokio_stream::Stream; -use tonic::{Code, Request, Response, Status, metadata::MetadataMap}; +use tonic::{Code, Status, metadata::MetadataMap}; use self::map::GatewayMap; use crate::{ db::{ - Device, GatewayEvent, User, + GatewayEvent, models::{wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, }, events::{GrpcEvent, GrpcRequestContext}, @@ -164,86 +154,6 @@ impl GatewayServer { } } - pub fn get_client_state_guard(&self) -> Result, GatewayServerError> { - let client_state = self - .client_state - .lock() - .map_err(|_| GatewayServerError::ClientStateMutexError)?; - debug!("Current VPN client state map: {client_state:?}"); - Ok(client_state) - } - - fn emit_event(&self, event: GrpcEvent) -> Result<(), GatewayServerError> { - Ok(self.grpc_event_tx.send(event)?) - } - - /// Helper method to fetch `Device` info from DB by pubkey and return appropriate errors - async fn fetch_device_from_db(&self, public_key: &str) -> Result>, Status> { - let device = Device::find_by_pubkey(&self.pool, public_key) - .await - .map_err(|err| { - error!("Failed to retrieve device with public key {public_key}: {err}",); - Status::new( - Code::Internal, - format!("Failed to retrieve device with public key {public_key}: {err}",), - ) - })?; - - Ok(device) - } - - /// Helper method to fetch `WireguardNetwork` info from DB and return appropriate errors - async fn fetch_location_from_db( - &self, - location_id: Id, - ) -> Result, Status> { - let location = match WireguardNetwork::find_by_id(&self.pool, location_id).await { - Ok(Some(location)) => location, - Ok(None) => { - error!("Location {location_id} not found"); - return Err(Status::new( - Code::Internal, - format!("Location {location_id} not found"), - )); - } - Err(err) => { - error!("Failed to retrieve location {location_id}: {err}",); - return Err(Status::new( - Code::Internal, - format!("Failed to retrieve location {location_id}: {err}",), - )); - } - }; - Ok(location) - } - - /// Helper method to fetch `User` info from DB and return appropriate errors - async fn fetch_user_from_db(&self, user_id: Id, public_key: &str) -> Result, Status> { - let user = match User::find_by_id(&self.pool, user_id).await { - Ok(Some(user)) => user, - Ok(None) => { - error!("User {user_id} assigned to device with public key {public_key} not found"); - return Err(Status::new( - Code::Internal, - format!("User assigned to device with public key {public_key} not found"), - )); - } - Err(err) => { - error!( - "Failed to retrieve user {user_id} for device with public key {public_key}: {err}", - ); - return Err(Status::new( - Code::Internal, - format!( - "Failed to retrieve user for device with public key {public_key}: {err}", - ), - )); - } - }; - - Ok(user) - } - /// Utility function extracting metadata fields during gRPC communication. fn extract_metadata(metadata: &MetadataMap) -> Result { let (version, _info) = version_info_from_metadata(metadata); @@ -292,166 +202,163 @@ impl WireguardPeerStats { } } -/* - -/// Process received Gateway events -/// -/// Main gRPC server uses a shared channel for broadcasting all Gateway events, -/// so the handler must determine if an event is relevant for the network being serviced. -async fn handle_events( - mut current_network: WireguardNetwork, - tx: UnboundedSender, - mut events_rx: Receiver, -) { - info!("Starting update stream network {current_network}"); - while let Some(event) = events_rx.recv().await { - debug!("Received networking state update event: {event:?}"); - let (update_type, update) = match event { - GatewayEvent::NetworkCreated(network, _fixme) => { - if network.id != current_network.id { - continue; - } - ( - UpdateType::Create, - update::Update::Network(Configuration { - name: network.name.clone(), - prvkey: network.prvkey.clone(), - addresses: network.address.to_string(), - port: network.port as u32, - peers: Vec::new(), - }), - ) - } - GatewayEvent::NetworkModified(network, peers, _fixme) => { - if network.id != current_network.id { - continue; - } - // update stored network data - current_network = network.clone(); - ( - UpdateType::Modify, - update::Update::Network(Configuration { - name: network.name, - prvkey: network.prvkey, - addresses: network.address.to_string(), - port: network.port as u32, - peers, - }), - ) - } - GatewayEvent::NetworkDeleted(network_id, network_name) => { - if network_id != current_network.id { - continue; - } - ( - UpdateType::Delete, - update::Update::Network(Configuration { - name: network_name.to_string(), - prvkey: String::new(), - addresses: Vec::new(), - port: 0, - peers: Vec::new(), - firewall_config: None, - }), - ) - } - GatewayEvent::DeviceCreated(device) => { - // check if a peer has to be added in the current network - match device - .network_info - .iter() - .find(|info| info.network_id == current_network.id) - { - Some(network_info) => { - if current_network.mfa_enabled && !network_info.is_authorized { - debug!( - "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", - device.device.name, current_network.name - ); - continue; - }; - let peer = Peer { - pubkey: device.device.wireguard_pubkey, - allowed_ips: vec![network_info.device_wireguard_ip.to_string()], - preshared_key: network_info.preshared_key.clone(), - keepalive_interval: Some(current_network.keepalive_interval as u32), - }; - (UpdateType::Create, update::Update::Peer(peer)) - } - None => continue, - } - } - GatewayEvent::DeviceModified(device) => { - // check if a peer has to be updated in the current network - match device - .network_info - .iter() - .find(|info| info.network_id == current_network.id) - { - Some(network_info) => { - if current_network.mfa_enabled && !network_info.is_authorized { - debug!( - "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", - device.device.name, current_network.name - ); - continue; - }; - let peer = Peer { - pubkey: device.device.wireguard_pubkey, - allowed_ips: vec![network_info.device_wireguard_ip.to_string()], - preshared_key: network_info.preshared_key.clone(), - keepalive_interval: Some(current_network.keepalive_interval as u32), - }; - (UpdateType::Modify, update::Update::Peer(peer)) - } - None => continue, - } - } - GatewayEvent::DeviceDeleted(device) => { - // check if a peer has to be updated in the current network - match device - .network_info - .iter() - .find(|info| info.network_id == current_network.id) - { - Some(_) => ( - UpdateType::Delete, - update::Update::Peer(Peer { - pubkey: device.device.wireguard_pubkey, - allowed_ips: Vec::new(), - preshared_key: None, - keepalive_interval: None, - }), - ), - None => continue, - } - } - GatewayEvent::FirewallConfigChanged(_fixme, _) => (), - GatewayEvent::FirewallDisabled(_id) => (), - }; +// /// Process received Gateway events +// /// +// /// Main gRPC server uses a shared channel for broadcasting all Gateway events, +// /// so the handler must determine if an event is relevant for the network being serviced. +// async fn handle_events( +// mut current_network: WireguardNetwork, +// tx: UnboundedSender, +// mut events_rx: BroadcastReceiver, +// ) { +// info!("Starting update stream network {current_network}"); +// while let Some(event) = events_rx.recv().await { +// debug!("Received networking state update event: {event:?}"); +// let (update_type, update) = match event { +// GatewayEvent::NetworkCreated(network, _fixme) => { +// if network.id != current_network.id { +// continue; +// } +// ( +// UpdateType::Create, +// update::Update::Network(Configuration { +// name: network.name.clone(), +// prvkey: network.prvkey.clone(), +// addresses: network.address.to_string(), +// port: network.port as u32, +// peers: Vec::new(), +// }), +// ) +// } +// GatewayEvent::NetworkModified(network, peers, _fixme) => { +// if network.id != current_network.id { +// continue; +// } +// // update stored network data +// current_network = network.clone(); +// ( +// UpdateType::Modify, +// update::Update::Network(Configuration { +// name: network.name, +// prvkey: network.prvkey, +// addresses: network.address.to_string(), +// port: network.port as u32, +// peers, +// }), +// ) +// } +// GatewayEvent::NetworkDeleted(network_id, network_name) => { +// if network_id != current_network.id { +// continue; +// } +// ( +// UpdateType::Delete, +// update::Update::Network(Configuration { +// name: network_name.to_string(), +// prvkey: String::new(), +// addresses: Vec::new(), +// port: 0, +// peers: Vec::new(), +// firewall_config: None, +// }), +// ) +// } +// GatewayEvent::DeviceCreated(device) => { +// // check if a peer has to be added in the current network +// match device +// .network_info +// .iter() +// .find(|info| info.network_id == current_network.id) +// { +// Some(network_info) => { +// if current_network.mfa_enabled && !network_info.is_authorized { +// debug!( +// "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", +// device.device.name, current_network.name +// ); +// continue; +// }; +// let peer = Peer { +// pubkey: device.device.wireguard_pubkey, +// allowed_ips: vec![network_info.device_wireguard_ip.to_string()], +// preshared_key: network_info.preshared_key.clone(), +// keepalive_interval: Some(current_network.keepalive_interval as u32), +// }; +// (UpdateType::Create, update::Update::Peer(peer)) +// } +// None => continue, +// } +// } +// GatewayEvent::DeviceModified(device) => { +// // check if a peer has to be updated in the current network +// match device +// .network_info +// .iter() +// .find(|info| info.network_id == current_network.id) +// { +// Some(network_info) => { +// if current_network.mfa_enabled && !network_info.is_authorized { +// debug!( +// "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", +// device.device.name, current_network.name +// ); +// continue; +// }; +// let peer = Peer { +// pubkey: device.device.wireguard_pubkey, +// allowed_ips: vec![network_info.device_wireguard_ip.to_string()], +// preshared_key: network_info.preshared_key.clone(), +// keepalive_interval: Some(current_network.keepalive_interval as u32), +// }; +// (UpdateType::Modify, update::Update::Peer(peer)) +// } +// None => continue, +// } +// } +// GatewayEvent::DeviceDeleted(device) => { +// // check if a peer has to be updated in the current network +// match device +// .network_info +// .iter() +// .find(|info| info.network_id == current_network.id) +// { +// Some(_) => ( +// UpdateType::Delete, +// update::Update::Peer(Peer { +// pubkey: device.device.wireguard_pubkey, +// allowed_ips: Vec::new(), +// preshared_key: None, +// keepalive_interval: None, +// }), +// ), +// None => continue, +// } +// } +// GatewayEvent::FirewallConfigChanged(_fixme, _) => (), +// GatewayEvent::FirewallDisabled(_id) => (), +// }; - let req = CoreResponse { - id: 0, - payload: Some(core_response::Payload::Update(Update { - update_type: update_type as i32, - update: Some(update), - })), - }; - if let Err(err) = tx.send(req) { - error!( - "Failed to send network update, network {current_network}, update type: {}, error: \ - {err}", - update_type.as_str_name() - ); - break; - } - debug!( - "Network update sent for network {current_network}, update type: {}", - update_type.as_str_name() - ); - } -} -*/ +// let req = CoreResponse { +// id: 0, +// payload: Some(core_response::Payload::Update(Update { +// update_type: update_type as i32, +// update: Some(update), +// })), +// }; +// if let Err(err) = tx.send(req) { +// error!( +// "Failed to send network update, network {current_network}, update type: {}, error: \ +// {err}", +// update_type.as_str_name() +// ); +// break; +// } +// debug!( +// "Network update sent for network {current_network}, update type: {}", +// update_type.as_str_name() +// ); +// } +// } /// Helper struct for handling gateway events struct GatewayUpdatesHandler { @@ -479,7 +386,7 @@ impl GatewayUpdatesHandler { } } - /// Process incoming gateway events + /// Process incoming Gateway events /// /// Main gRPC server uses a shared channel for broadcasting all gateway events /// so the handler must determine if an event is relevant for the network being serviced @@ -797,58 +704,14 @@ impl GatewayUpdatesHandler { } } -pub struct GatewayUpdatesStream { - task_handle: JoinHandle<()>, - rx: Receiver>, - network_id: Id, - gateway_hostname: String, - gateway_state: Arc>, - pool: PgPool, -} - -impl GatewayUpdatesStream { - #[must_use] - pub fn new( - task_handle: JoinHandle<()>, - rx: Receiver>, - network_id: Id, - gateway_hostname: String, - gateway_state: Arc>, - pool: PgPool, - ) -> Self { - Self { - task_handle, - rx, - network_id, - gateway_hostname, - gateway_state, - pool, - } - } -} - -impl Stream for GatewayUpdatesStream { - type Item = Result; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.rx).poll_recv(cx) - } -} - -impl Drop for GatewayUpdatesStream { - fn drop(&mut self) { - info!("Client disconnected"); - // terminate update task - self.task_handle.abort(); - // update gateway state - // TODO: possibly use a oneshot channel instead - self.gateway_state - .lock() - .unwrap() - .disconnect_gateway(self.network_id, self.gateway_hostname.clone(), &self.pool) - .expect("Unable to disconnect gateway."); - } -} +// pub struct GatewayUpdatesStream { +// task_handle: JoinHandle<()>, +// rx: Receiver>, +// network_id: Id, +// gateway_hostname: String, +// gateway_state: Arc>, +// pool: PgPool, +// } // #[tonic::async_trait] // impl gateway_service_server::GatewayService for GatewayServer { @@ -871,7 +734,7 @@ impl Drop for GatewayUpdatesStream { // // version = version.to_string(), info); // // let _guard = span.enter(); // loop { -// // Wait for a message or update client map at least once a mninute, if no messages are +// // Wait for a message or update client map at least once a minute, if no messages are // // received. // let stats_update = tokio::select! { // message = stream.message() => { diff --git a/crates/defguard_core/src/grpc/gateway/state.rs b/crates/defguard_core/src/grpc/gateway/state.rs index 788801106..0f9b10f62 100644 --- a/crates/defguard_core/src/grpc/gateway/state.rs +++ b/crates/defguard_core/src/grpc/gateway/state.rs @@ -13,7 +13,7 @@ use utoipa::ToSchema; use uuid::Uuid; use crate::{ - db::models::gateway::Gateway, + // db::models::gateway::Gateway, grpc::MIN_GATEWAY_VERSION, handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, }; diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index cee430a61..8b9e126ee 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -15,7 +15,7 @@ use defguard_common::{ use defguard_mail::Mail; use defguard_version::{ ComponentInfo, DefguardComponent, Version, client::ClientVersionInterceptor, - get_tracing_variables, server::DefguardVersionLayer, + get_tracing_variables, }; use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; use reqwest::Url; @@ -65,7 +65,7 @@ use crate::{ ldap::utils::ldap_update_user_state, }, events::{BidiStreamEvent, GrpcEvent}, - grpc::gateway::{client_state::ClientMap, map::GatewayMap}, + grpc::gateway::client_state::ClientMap, server_config, version::{ IncompatibleComponents, IncompatibleProxyData, MIN_GATEWAY_VERSION, @@ -556,8 +556,10 @@ const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; /// Bi-directional gRPC stream for comminication with Defguard Gateway. pub async fn run_grpc_gateway_stream( pool: PgPool, + client_state: Arc>, events_tx: Sender, mail_tx: UnboundedSender, + grpc_event_tx: UnboundedSender, ) -> Result<(), anyhow::Error> { let config = server_config(); let tls_config = config.grpc_client_tls_config()?; @@ -572,8 +574,10 @@ pub async fn run_grpc_gateway_stream( gateway, tls_config.clone(), pool.clone(), + Arc::clone(&client_state), events_tx.clone(), mail_tx.clone(), + grpc_event_tx.clone(), )?; let abort_handle = tasks.spawn(async move { gateway_handler.handle_connection().await; @@ -581,8 +585,7 @@ pub async fn run_grpc_gateway_stream( Ok(abort_handle) }; - let gateways = Gateway::all(&pool).await?; - for gateway in gateways { + for gateway in Gateway::all(&pool).await? { let id = gateway.id; let abort_handle = launch_gateway_handler(gateway)?; abort_handles.insert(id, abort_handle); @@ -757,15 +760,9 @@ pub async fn run_grpc_bidi_stream( pub async fn run_grpc_server( worker_state: Arc>, pool: PgPool, - gateway_state: Arc>, - client_state: Arc>, - wireguard_tx: Sender, - mail_tx: UnboundedSender, grpc_cert: Option, grpc_key: Option, failed_logins: Arc>, - grpc_event_tx: UnboundedSender, - incompatible_components: Arc>, ) -> Result<(), anyhow::Error> { // Build gRPC services let server = if let (Some(cert), Some(key)) = (grpc_cert, grpc_key) { @@ -775,19 +772,7 @@ pub async fn run_grpc_server( Server::builder() }; - let router = build_grpc_service_router( - server, - pool, - worker_state, - // gateway_state, - // client_state, - // wireguard_tx, - // mail_tx, - failed_logins, - // grpc_event_tx, - // incompatible_components, - ) - .await?; + let router = build_grpc_service_router(server, pool, worker_state, failed_logins).await?; // Run gRPC server let addr = SocketAddr::new( @@ -806,12 +791,7 @@ pub async fn build_grpc_service_router( server: Server, pool: PgPool, worker_state: Arc>, - // gateway_state: Arc>, - // client_state: Arc>, - // wireguard_tx: Sender, - // mail_tx: UnboundedSender, failed_logins: Arc>, - // grpc_event_tx: UnboundedSender, // incompatible_components: Arc>, ) -> Result { let auth_service = AuthServiceServer::new(AuthServer::new(pool.clone(), failed_logins)); From febd7e628f3e4d679657642a5a4732e7874a5b0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Mon, 1 Dec 2025 13:55:58 +0100 Subject: [PATCH 11/17] Unclog GatewayUpdatesHandler --- .../defguard_core/src/grpc/gateway/handler.rs | 52 +- crates/defguard_core/src/grpc/gateway/mod.rs | 655 +++++------------- 2 files changed, 219 insertions(+), 488 deletions(-) diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 30c6ecc6f..e5527ee16 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -80,7 +80,10 @@ impl GatewayHandler { } /// Send network and VPN configuration to Gateway. - async fn send_configuration(&self, tx: &UnboundedSender) -> Result<(), Status> { + async fn send_configuration( + &self, + tx: &UnboundedSender, + ) -> Result, Status> { debug!("Sending configuration to Gateway"); let network_id = self.gateway.network_id; // let hostname = Self::get_gateway_hostname(request.metadata())?; @@ -146,7 +149,7 @@ impl GatewayHandler { match tx.send(req) { Ok(()) => { info!("Configuration sent to {}, network {network}", self.gateway); - Ok(()) + Ok(network) } Err(err) => { error!("Failed to send configuration sent to {}", self.gateway); @@ -360,13 +363,32 @@ impl GatewayHandler { // Send network configuration to Gateway. match self.send_configuration(&tx).await { - Ok(()) => { + Ok(network) => { info!("Sent configuration to {}", self.gateway); config_sent = true; let _ = self .gateway .touch_connected(&self.pool, config_request.hostname) .await; + let guh = super::GatewayUpdatesHandler::new( + self.gateway.network_id, + network, + self + .gateway + .hostname + .as_ref() + .cloned() + .unwrap_or_default() + .clone(), + self.events_tx.subscribe(), + tx.clone(), + ); + // tokio::spawn(super::handle_events( + // network, + // // self.gateway.hostname.unwrap_or_default().clone(), + // tx.clone(), + // self.events_tx.subscribe(), + // )); } Err(err) => { error!( @@ -375,26 +397,6 @@ impl GatewayHandler { ); } } - - // Start observing configuration changes. - let Ok(Some(network)) = WireguardNetwork::find_by_id( - &self.pool, - self.gateway.network_id, - ) - .await - else { - error!( - "Failed to fetch network ID {} from the database", - self.gateway.network_id - ); - continue; - }; - // tokio::spawn(super::handle_events( - // network, - // self.gateway.hostname.unwrap_or_default().clone(), - // tx.clone(), - // self.events_tx.subscribe(), - // )); } Some(core_request::Payload::PeerStats(peer_stats)) => { if !config_sent { @@ -408,7 +410,7 @@ impl GatewayHandler { let public_key = peer_stats.public_key.clone(); - // fetch device from DB + // Fetch device from database. // TODO: fetch only when device has changed and use client state // otherwise let Ok(Some(device)) = self.fetch_device_from_db(&public_key).await @@ -561,7 +563,7 @@ impl GatewayHandler { }; } Err(err) => { - error!("Disconnected from gateway at {uri}, error: {err}"); + error!("Disconnected from Gateway at {uri}, error: {err}"); // Important: call this funtion before setting disconnection time. self.send_disconnect_notification().await; let _ = self.gateway.touch_disconnected(&self.pool).await; diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 329a67f85..ae2ff3ce5 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -9,7 +9,7 @@ use defguard_common::db::{Id, NoId}; use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, - gateway::{Configuration, Peer, PeerStats, Update, update}, + gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; use defguard_version::version_info_from_metadata; use semver::Version; @@ -202,163 +202,164 @@ impl WireguardPeerStats { } } -// /// Process received Gateway events -// /// -// /// Main gRPC server uses a shared channel for broadcasting all Gateway events, -// /// so the handler must determine if an event is relevant for the network being serviced. -// async fn handle_events( -// mut current_network: WireguardNetwork, -// tx: UnboundedSender, -// mut events_rx: BroadcastReceiver, -// ) { -// info!("Starting update stream network {current_network}"); -// while let Some(event) = events_rx.recv().await { -// debug!("Received networking state update event: {event:?}"); -// let (update_type, update) = match event { -// GatewayEvent::NetworkCreated(network, _fixme) => { -// if network.id != current_network.id { -// continue; -// } -// ( -// UpdateType::Create, -// update::Update::Network(Configuration { -// name: network.name.clone(), -// prvkey: network.prvkey.clone(), -// addresses: network.address.to_string(), -// port: network.port as u32, -// peers: Vec::new(), -// }), -// ) -// } -// GatewayEvent::NetworkModified(network, peers, _fixme) => { -// if network.id != current_network.id { -// continue; -// } -// // update stored network data -// current_network = network.clone(); -// ( -// UpdateType::Modify, -// update::Update::Network(Configuration { -// name: network.name, -// prvkey: network.prvkey, -// addresses: network.address.to_string(), -// port: network.port as u32, -// peers, -// }), -// ) -// } -// GatewayEvent::NetworkDeleted(network_id, network_name) => { -// if network_id != current_network.id { -// continue; -// } -// ( -// UpdateType::Delete, -// update::Update::Network(Configuration { -// name: network_name.to_string(), -// prvkey: String::new(), -// addresses: Vec::new(), -// port: 0, -// peers: Vec::new(), -// firewall_config: None, -// }), -// ) -// } -// GatewayEvent::DeviceCreated(device) => { -// // check if a peer has to be added in the current network -// match device -// .network_info -// .iter() -// .find(|info| info.network_id == current_network.id) -// { -// Some(network_info) => { -// if current_network.mfa_enabled && !network_info.is_authorized { -// debug!( -// "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", -// device.device.name, current_network.name -// ); -// continue; -// }; -// let peer = Peer { -// pubkey: device.device.wireguard_pubkey, -// allowed_ips: vec![network_info.device_wireguard_ip.to_string()], -// preshared_key: network_info.preshared_key.clone(), -// keepalive_interval: Some(current_network.keepalive_interval as u32), -// }; -// (UpdateType::Create, update::Update::Peer(peer)) -// } -// None => continue, -// } -// } -// GatewayEvent::DeviceModified(device) => { -// // check if a peer has to be updated in the current network -// match device -// .network_info -// .iter() -// .find(|info| info.network_id == current_network.id) -// { -// Some(network_info) => { -// if current_network.mfa_enabled && !network_info.is_authorized { -// debug!( -// "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", -// device.device.name, current_network.name -// ); -// continue; -// }; -// let peer = Peer { -// pubkey: device.device.wireguard_pubkey, -// allowed_ips: vec![network_info.device_wireguard_ip.to_string()], -// preshared_key: network_info.preshared_key.clone(), -// keepalive_interval: Some(current_network.keepalive_interval as u32), -// }; -// (UpdateType::Modify, update::Update::Peer(peer)) -// } -// None => continue, -// } -// } -// GatewayEvent::DeviceDeleted(device) => { -// // check if a peer has to be updated in the current network -// match device -// .network_info -// .iter() -// .find(|info| info.network_id == current_network.id) -// { -// Some(_) => ( -// UpdateType::Delete, -// update::Update::Peer(Peer { -// pubkey: device.device.wireguard_pubkey, -// allowed_ips: Vec::new(), -// preshared_key: None, -// keepalive_interval: None, -// }), -// ), -// None => continue, -// } -// } -// GatewayEvent::FirewallConfigChanged(_fixme, _) => (), -// GatewayEvent::FirewallDisabled(_id) => (), -// }; - -// let req = CoreResponse { -// id: 0, -// payload: Some(core_response::Payload::Update(Update { -// update_type: update_type as i32, -// update: Some(update), -// })), -// }; -// if let Err(err) = tx.send(req) { -// error!( -// "Failed to send network update, network {current_network}, update type: {}, error: \ -// {err}", -// update_type.as_str_name() -// ); -// break; -// } -// debug!( -// "Network update sent for network {current_network}, update type: {}", -// update_type.as_str_name() -// ); -// } -// } +/// Process received Gateway events +/// +/// Main gRPC server uses a shared channel for broadcasting all Gateway events, +/// so the handler must determine if an event is relevant for the network being serviced. +async fn handle_events( + mut current_network: WireguardNetwork, + // gateway_hostname: String, + tx: UnboundedSender, + mut events_rx: BroadcastReceiver, +) { + info!("Starting update stream network {current_network}"); + // while let Some(event) = events_rx.recv().await { + // debug!("Received networking state update event: {event:?}"); + // let (update_type, update) = match event { + // GatewayEvent::NetworkCreated(network, _fixme) => { + // if network.id != current_network.id { + // continue; + // } + // ( + // UpdateType::Create, + // update::Update::Network(Configuration { + // name: network.name.clone(), + // prvkey: network.prvkey.clone(), + // addresses: network.address.to_string(), + // port: network.port as u32, + // peers: Vec::new(), + // }), + // ) + // } + // GatewayEvent::NetworkModified(network, peers, _fixme) => { + // if network.id != current_network.id { + // continue; + // } + // // update stored network data + // current_network = network.clone(); + // ( + // UpdateType::Modify, + // update::Update::Network(Configuration { + // name: network.name, + // prvkey: network.prvkey, + // addresses: network.address.to_string(), + // port: network.port as u32, + // peers, + // }), + // ) + // } + // GatewayEvent::NetworkDeleted(network_id, network_name) => { + // if network_id != current_network.id { + // continue; + // } + // ( + // UpdateType::Delete, + // update::Update::Network(Configuration { + // name: network_name.to_string(), + // prvkey: String::new(), + // addresses: Vec::new(), + // port: 0, + // peers: Vec::new(), + // firewall_config: None, + // }), + // ) + // } + // GatewayEvent::DeviceCreated(device) => { + // // check if a peer has to be added in the current network + // match device + // .network_info + // .iter() + // .find(|info| info.network_id == current_network.id) + // { + // Some(network_info) => { + // if current_network.mfa_enabled && !network_info.is_authorized { + // debug!( + // "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", + // device.device.name, current_network.name + // ); + // continue; + // }; + // let peer = Peer { + // pubkey: device.device.wireguard_pubkey, + // allowed_ips: vec![network_info.device_wireguard_ip.to_string()], + // preshared_key: network_info.preshared_key.clone(), + // keepalive_interval: Some(current_network.keepalive_interval as u32), + // }; + // (UpdateType::Create, update::Update::Peer(peer)) + // } + // None => continue, + // } + // } + // GatewayEvent::DeviceModified(device) => { + // // check if a peer has to be updated in the current network + // match device + // .network_info + // .iter() + // .find(|info| info.network_id == current_network.id) + // { + // Some(network_info) => { + // if current_network.mfa_enabled && !network_info.is_authorized { + // debug!( + // "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", + // device.device.name, current_network.name + // ); + // continue; + // }; + // let peer = Peer { + // pubkey: device.device.wireguard_pubkey, + // allowed_ips: vec![network_info.device_wireguard_ip.to_string()], + // preshared_key: network_info.preshared_key.clone(), + // keepalive_interval: Some(current_network.keepalive_interval as u32), + // }; + // (UpdateType::Modify, update::Update::Peer(peer)) + // } + // None => continue, + // } + // } + // GatewayEvent::DeviceDeleted(device) => { + // // check if a peer has to be updated in the current network + // match device + // .network_info + // .iter() + // .find(|info| info.network_id == current_network.id) + // { + // Some(_) => ( + // UpdateType::Delete, + // update::Update::Peer(Peer { + // pubkey: device.device.wireguard_pubkey, + // allowed_ips: Vec::new(), + // preshared_key: None, + // keepalive_interval: None, + // }), + // ), + // None => continue, + // } + // } + // GatewayEvent::FirewallConfigChanged(_fixme, _) => (), + // GatewayEvent::FirewallDisabled(_id) => (), + // }; + + // let req = CoreResponse { + // id: 0, + // payload: Some(core_response::Payload::Update(Update { + // update_type: update_type as i32, + // update: Some(update), + // })), + // }; + // if let Err(err) = tx.send(req) { + // error!( + // "Failed to send network update, network {current_network}, update type: {}, error: \ + // {err}", + // update_type.as_str_name() + // ); + // break; + // } + // debug!( + // "Network update sent for network {current_network}, update type: {}", + // update_type.as_str_name() + // ); + // } +} /// Helper struct for handling gateway events struct GatewayUpdatesHandler { @@ -366,7 +367,7 @@ struct GatewayUpdatesHandler { network: WireguardNetwork, gateway_hostname: String, events_rx: BroadcastReceiver, - tx: mpsc::Sender>, + tx: UnboundedSender, } impl GatewayUpdatesHandler { @@ -375,7 +376,7 @@ impl GatewayUpdatesHandler { network: WireguardNetwork, gateway_hostname: String, events_rx: BroadcastReceiver, - tx: mpsc::Sender>, + tx: UnboundedSender, ) -> Self { Self { network_id, @@ -545,9 +546,9 @@ impl GatewayUpdatesHandler { update_type: i32, ) -> Result<(), Status> { debug!("Sending network update for network {network}"); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type, update: Some(update::Update::Network(Configuration { name: network.name.clone(), @@ -557,9 +558,8 @@ impl GatewayUpdatesHandler { peers, firewall_config, })), - })) - .await - { + })), + }) { let msg = format!( "Failed to send network update, network {network}, update type: {update_type} ({}), error: {err}", if update_type == 0 { "CREATE" } else { "MODIFY" }, @@ -577,9 +577,9 @@ impl GatewayUpdatesHandler { "Sending network delete command for network {}", self.network ); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type: 2, update: Some(update::Update::Network(Configuration { name: network_name.to_string(), @@ -589,9 +589,8 @@ impl GatewayUpdatesHandler { peers: Vec::new(), firewall_config: None, })), - })) - .await - { + })), + }) { let msg = format!( "Failed to send network update, network {}, update type: 2 (DELETE), error: {err}", self.network, @@ -606,14 +605,13 @@ impl GatewayUpdatesHandler { /// Send update peer command to gateway async fn send_peer_update(&self, peer: Peer, update_type: i32) -> Result<(), Status> { debug!("Sending peer update for network {}", self.network); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type, update: Some(update::Update::Peer(peer)), - })) - .await - { + })), + }) { let msg = format!( "Failed to send peer update for network {}, update type: {update_type} ({}), error: {err}", self.network, @@ -629,9 +627,9 @@ impl GatewayUpdatesHandler { /// Send delete peer command to gateway async fn send_peer_delete(&self, peer_pubkey: &str) -> Result<(), Status> { debug!("Sending peer delete for network {}", self.network); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type: 2, update: Some(update::Update::Peer(Peer { pubkey: peer_pubkey.into(), @@ -639,9 +637,8 @@ impl GatewayUpdatesHandler { preshared_key: None, keepalive_interval: None, })), - })) - .await - { + })), + }) { let msg = format!( "Failed to send peer update for network {}, peer {peer_pubkey}, update type: 2 (DELETE), error: {err}", self.network, @@ -659,14 +656,13 @@ impl GatewayUpdatesHandler { "Sending firewall config update for network {} with config {firewall_config:?}", self.network ); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type: 1, update: Some(update::Update::FirewallConfig(firewall_config)), - })) - .await - { + })), + }) { let msg = format!( "Failed to send firewall config update for network {}, error: {err}", self.network, @@ -684,14 +680,13 @@ impl GatewayUpdatesHandler { "Sending firewall disable command for network {}", self.network ); - if let Err(err) = self - .tx - .send(Ok(Update { + if let Err(err) = self.tx.send(CoreResponse { + id: 0, + payload: Some(core_response::Payload::Update(Update { update_type: 2, update: Some(update::Update::DisableFirewall(())), - })) - .await - { + })), + }) { let msg = format!( "Failed to send firewall disable command for network {}, error: {err}", self.network, @@ -716,273 +711,7 @@ impl GatewayUpdatesHandler { // #[tonic::async_trait] // impl gateway_service_server::GatewayService for GatewayServer { // type UpdatesStream = GatewayUpdatesStream; - -// /// Retrieve stats from gateway and save it to database -// async fn stats( -// &self, -// request: Request>, -// ) -> Result, Status> { -// let GatewayMetadata { -// network_id, -// hostname, -// .. -// } = Self::extract_metadata(request.metadata())?; -// let mut stream = request.into_inner(); -// let mut disconnect_timer = interval(Duration::from_secs(PEER_DISCONNECT_INTERVAL)); -// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. -// // let span = tracing::info_span!("gateway_stats", component = %DefguardComponent::Gateway, -// // version = version.to_string(), info); -// // let _guard = span.enter(); -// loop { -// // Wait for a message or update client map at least once a minute, if no messages are -// // received. -// let stats_update = tokio::select! { -// message = stream.message() => { -// match message? { -// Some(update) => update, -// None => break, // Stream ended -// } -// } -// _ = disconnect_timer.tick() => { -// debug!("No stats updates received in last {PEER_DISCONNECT_INTERVAL} seconds. \ -// Updating disconnected VPN clients"); -// // fetch location to get current peer disconnect threshold -// let location = self.fetch_location_from_db(network_id).await?; - -// // perform client state operations in a dedicated block to drop mutex guard -// let disconnected_clients = { -// // acquire lock on client state map -// let mut client_map = self.get_client_state_guard()?; - -// // disconnect inactive clients -// client_map.disconnect_inactive_vpn_clients_for_location(&location -// )? -// }; - -// // emit client disconnect events -// for (device, context) in disconnected_clients { -// self.emit_event(GrpcEvent::ClientDisconnected { -// context, -// location: location.clone(), -// device, -// })?; -// }; -// continue; -// } -// }; - -// debug!("Received stats message: {stats_update:?}"); -// let Some(stats_update::Payload::PeerStats(peer_stats)) = stats_update.payload else { -// debug!("Received stats message is empty, skipping."); -// continue; -// }; -// let public_key = peer_stats.public_key.clone(); - -// // fetch device from DB -// // TODO: fetch only when device has changed and use client state otherwise -// let device = match self.fetch_device_from_db(&public_key).await? { -// Some(device) => device, -// None => { -// warn!( -// "Received stats update for a device which does not exist: {public_key}, skipping." -// ); -// continue; -// } -// }; - -// // copy device ID for easier reference later -// let device_id = device.id; - -// // fetch user and location from DB for activity log -// // TODO: cache usernames since they don't change -// let user = self.fetch_user_from_db(device.user_id, &public_key).await?; -// let location = self.fetch_location_from_db(network_id).await?; - -// // convert stats to DB storage format -// let stats = WireguardPeerStats::from_peer_stats(peer_stats, network_id, device_id); - -// // only perform client state update if stats include an endpoint IP -// // otherwise a peer was added to the gateway interface -// // but has not connected yet -// if let Some(endpoint) = &stats.endpoint { -// // parse client endpoint IP -// let socket_addr: SocketAddr = endpoint.clone().parse().map_err(|err| { -// error!("Failed to parse VPN client endpoint: {err}"); -// Status::new( -// Code::Internal, -// format!("Failed to parse VPN client endpoint: {err}"), -// ) -// })?; - -// // perform client state operations in a dedicated block to drop mutex guard -// let disconnected_clients = { -// // acquire lock on client state map -// let mut client_map = self.get_client_state_guard()?; - -// // update connected clients map -// match client_map.get_vpn_client(network_id, &public_key) { -// Some(client_state) => { -// // update connected client state -// client_state.update_client_state( -// device, -// socket_addr, -// stats.latest_handshake, -// stats.upload, -// stats.download, -// ); -// } -// None => { -// // don't mark inactive peers as connected -// if (Utc::now().naive_utc() - stats.latest_handshake) -// < TimeDelta::seconds(location.peer_disconnect_threshold.into()) -// { -// // mark new VPN client as connected -// client_map.connect_vpn_client( -// network_id, -// &hostname, -// &public_key, -// &device, -// &user, -// socket_addr, -// &stats, -// )?; - -// // emit connection event -// let context = GrpcRequestContext::new( -// user.id, -// user.username.clone(), -// socket_addr.ip(), -// device.id, -// device.name.clone(), -// location.clone(), -// ); -// self.emit_event(GrpcEvent::ClientConnected { -// context, -// location: location.clone(), -// device: device.clone(), -// })?; -// } -// } -// } - -// // disconnect inactive clients -// client_map.disconnect_inactive_vpn_clients_for_location(&location)? -// }; - -// // emit client disconnect events -// for (device, context) in disconnected_clients { -// self.emit_event(GrpcEvent::ClientDisconnected { -// context, -// location: location.clone(), -// device, -// })?; -// } -// } - -// // Save stats to db -// let stats = match stats.save(&self.pool).await { -// Ok(stats) => stats, -// Err(err) => { -// error!("Saving WireGuard peer stats to db failed: {err}"); -// return Err(Status::new( -// Code::Internal, -// format!("Saving WireGuard peer stats to db failed: {err}"), -// )); -// } -// }; -// info!("Saved WireGuard peer stats to db."); -// debug!("WireGuard peer stats: {stats:?}"); -// } - -// Ok(Response::new(())) -// } - -// async fn config( -// &self, -// request: Request, -// ) -> Result, Status> { -// debug!("Sending configuration to gateway client."); -// let GatewayMetadata { -// network_id, -// hostname, -// version, -// .. -// // info, -// } = Self::extract_metadata(request.metadata())?; -// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. -// // let span = tracing::info_span!("gateway_config", component = %DefguardComponent::Gateway, -// // version = version.to_string(), info); -// // let _guard = span.enter(); - -// let mut conn = self.pool.acquire().await.map_err(|e| { -// error!("Failed to acquire DB connection: {e}"); -// Status::new( -// Code::Internal, -// "Failed to acquire DB connection".to_string(), -// ) -// })?; - -// let mut network = WireguardNetwork::find_by_id(&mut *conn, network_id) -// .await -// .map_err(|e| { -// error!("Network {network_id} not found"); -// Status::new(Code::Internal, format!("Failed to retrieve network: {e}")) -// })? -// .ok_or_else(|| { -// Status::new( -// Code::Internal, -// format!("Network with id {network_id} not found"), -// ) -// })?; - -// debug!("Sending configuration to gateway client, network {network}."); - -// // store connected gateway in memory -// { -// let mut state = self.gateway_state.lock().unwrap(); -// state.add_gateway( -// network_id, -// &network.name, -// hostname, -// request.into_inner().name, -// self.mail_tx.clone(), -// version, -// ); -// } - -// network.connected_at = Some(Utc::now().naive_utc()); -// if let Err(err) = network.save(&mut *conn).await { -// error!("Failed to save updated network {network_id} in the database, status: {err}"); -// } - -// let peers = network.get_peers(&mut *conn).await.map_err(|error| { -// error!("Failed to fetch peers from the database for network {network_id}: {error}",); -// Status::new( -// Code::Internal, -// format!("Failed to retrieve peers from the database for network: {network_id}"), -// ) -// })?; -// let maybe_firewall_config = -// network -// .try_get_firewall_config(&mut conn) -// .await -// .map_err(|err| { -// error!("Failed to generate firewall config for network {network_id}: {err}"); -// Status::new( -// Code::Internal, -// format!("Failed to generate firewall config for network: {network_id}"), -// ) -// })?; - -// info!("Configuration sent to gateway client, network {network}."); - -// Ok(Response::new(gen_config( -// &network, -// peers, -// maybe_firewall_config, -// ))) -// } - +// // async fn updates(&self, request: Request<()>) -> Result, Status> { // let GatewayMetadata { // network_id, From 0a759cfe208061c191a44e0bf92aa7a3cc1f8646 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 2 Dec 2025 14:15:58 +0100 Subject: [PATCH 12/17] Gateway metadata --- Cargo.lock | 28 +- .../defguard_core/src/grpc/gateway/handler.rs | 87 +++++- crates/defguard_core/src/grpc/gateway/mod.rs | 266 +----------------- 3 files changed, 96 insertions(+), 285 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4e4af3ad9..009d6fab3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3636,12 +3636,11 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.8.3" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" dependencies = [ "fixedbitset", - "hashbrown 0.15.5", "indexmap 2.12.1", ] @@ -3905,9 +3904,9 @@ dependencies = [ [[package]] name = "prost" -version = "0.14.2" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "101fec8d036f8d9d4a1e8ebf90d566d1d798f3b1aa379d2576a54a0d9acea5bd" +checksum = "7231bd9b3d3d33c86b58adbac74b5ec0ad9f496b19d22801d773636feaa95f3d" dependencies = [ "bytes", "prost-derive", @@ -3915,14 +3914,15 @@ dependencies = [ [[package]] name = "prost-build" -version = "0.14.2" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "528a07106a21e01f4880c09818d0b7e73d0f0993536ddfff161754b5c20a086c" +checksum = "ac6c3320f9abac597dcbc668774ef006702672474aad53c6d596b62e487b40b1" dependencies = [ "heck", "itertools 0.14.0", "log", "multimap", + "once_cell", "petgraph", "prettyplease", "prost", @@ -3936,9 +3936,9 @@ dependencies = [ [[package]] name = "prost-derive" -version = "0.14.2" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2d93e596a829ebe00afa41c3a056e6308d6b8a4c7d869edf184e2c91b1ba564" +checksum = "9120690fafc389a67ba3803df527d0ec9cbbc9cc45e4cc20b332996dfb672425" dependencies = [ "anyhow", "itertools 0.14.0", @@ -3949,9 +3949,9 @@ dependencies = [ [[package]] name = "prost-types" -version = "0.14.2" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5d7b7346e150de32340ae3390b8b3ffa37ad93ec31fb5dad86afe817619e4e7" +checksum = "b9b4db3d6da204ed77bb26ba83b6122a73aeb2e87e25fbf7ad2e84c4ccbf8f72" dependencies = [ "prost", ] @@ -6062,13 +6062,13 @@ checksum = "e2eebbbfe4093922c2b6734d7c679ebfebd704a0d7e56dfcb0d05818ce28977d" [[package]] name = "uuid" -version = "1.18.1" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" +checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ "getrandom 0.3.4", "js-sys", - "serde", + "serde_core", "wasm-bindgen", ] diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index e5527ee16..bc3213522 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -11,6 +11,8 @@ use chrono::{TimeDelta, Utc}; use defguard_common::{auth::claims::Claims, db::Id}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; +use defguard_version::version_info_from_metadata; +use semver::Version; use sqlx::PgPool; use tokio::{ sync::{ @@ -22,6 +24,7 @@ use tokio::{ use tokio_stream::wrappers::UnboundedReceiverStream; use tonic::{ Code, Status, + metadata::MetadataMap, transport::{ClientTlsConfig, Endpoint}, }; @@ -47,6 +50,14 @@ pub(crate) struct GatewayHandler { grpc_event_tx: UnboundedSender, } +/// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. +struct GatewayMetadata { + network_id: Id, + hostname: String, + version: Version, + // info: String, +} + impl GatewayHandler { pub(crate) fn new( gateway: Gateway, @@ -79,6 +90,57 @@ impl GatewayHandler { }) } + fn get_network_id(metadata: &MetadataMap) -> Result { + match Self::get_network_id_from_metadata(metadata) { + Some(m) => Ok(m), + None => Err(Status::new( + Code::Internal, + "Network ID was not found in metadata", + )), + } + } + + // parse network id from gateway request metadata from intercepted information from JWT token + fn get_network_id_from_metadata(metadata: &MetadataMap) -> Option { + if let Some(ascii_value) = metadata.get("gateway_network_id") { + if let Ok(slice) = ascii_value.clone().to_str() { + if let Ok(id) = slice.parse::() { + return Some(id); + } + } + } + None + } + + // extract gateway hostname from request headers + fn get_gateway_hostname(metadata: &MetadataMap) -> Result { + match metadata.get("hostname") { + Some(ascii_value) => { + let hostname = ascii_value.to_str().map_err(|_| { + Status::new( + Code::Internal, + "Failed to parse gateway hostname from request metadata", + ) + })?; + Ok(hostname.into()) + } + None => Err(Status::new( + Code::Internal, + "Gateway hostname not found in request metadata", + )), + } + } + + /// Utility function extracting metadata fields during gRPC communication. + fn extract_metadata(metadata: &MetadataMap) -> Result { + let (version, _info) = version_info_from_metadata(metadata); + Ok(GatewayMetadata { + network_id: Self::get_network_id(metadata)?, + hostname: Self::get_gateway_hostname(metadata)?, + version, + }) + } + /// Send network and VPN configuration to Gateway. async fn send_configuration( &self, @@ -86,7 +148,6 @@ impl GatewayHandler { ) -> Result, Status> { debug!("Sending configuration to Gateway"); let network_id = self.gateway.network_id; - // let hostname = Self::get_gateway_hostname(request.metadata())?; let mut conn = self.pool.acquire().await.map_err(|err| { error!("Failed to acquire DB connection: {err}"); @@ -310,6 +371,15 @@ impl GatewayHandler { }; info!("Connected to Defguard Gateway {uri}"); + let Ok(GatewayMetadata { + network_id, + hostname, + .. + // info, + }) = Self::extract_metadata(response.metadata()) else { + continue; + }; + let mut resp_stream = response.into_inner(); let mut config_sent = false; @@ -322,6 +392,7 @@ impl GatewayHandler { Ok(Some(received)) => { info!("Received message from Gateway."); debug!("Message from Gateway {uri}"); + match received.payload { Some(core_request::Payload::ConfigRequest(config_request)) => { if config_sent { @@ -370,11 +441,10 @@ impl GatewayHandler { .gateway .touch_connected(&self.pool, config_request.hostname) .await; - let guh = super::GatewayUpdatesHandler::new( + let mut guh = super::GatewayUpdatesHandler::new( self.gateway.network_id, network, - self - .gateway + self.gateway .hostname .as_ref() .cloned() @@ -383,12 +453,9 @@ impl GatewayHandler { self.events_tx.subscribe(), tx.clone(), ); - // tokio::spawn(super::handle_events( - // network, - // // self.gateway.hostname.unwrap_or_default().clone(), - // tx.clone(), - // self.events_tx.subscribe(), - // )); + tokio::spawn(async move { + guh.run().await; + }); } Err(err) => { error!( diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index ae2ff3ce5..7be29eac1 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -11,15 +11,13 @@ use defguard_proto::{ enterprise::firewall::FirewallConfig, gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; -use defguard_version::version_info_from_metadata; -use semver::Version; use sqlx::PgPool; use thiserror::Error; use tokio::sync::{ broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::{self, UnboundedSender, error::SendError}, + mpsc::{UnboundedSender, error::SendError}, }; -use tonic::{Code, Status, metadata::MetadataMap}; +use tonic::{Code, Status}; use self::map::GatewayMap; use crate::{ @@ -84,87 +82,6 @@ pub struct GatewayServer { grpc_event_tx: UnboundedSender, } -/// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. -struct GatewayMetadata { - network_id: Id, - hostname: String, - version: Version, - // info: String, -} - -impl GatewayServer { - /// Create new gateway server instance - #[must_use] - pub fn new( - pool: PgPool, - gateway_state: Arc>, - client_state: Arc>, - wireguard_tx: Sender, - mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, - ) -> Self { - Self { - pool, - gateway_state, - client_state, - wireguard_tx, - mail_tx, - grpc_event_tx, - } - } - - fn get_network_id(metadata: &MetadataMap) -> Result { - match Self::get_network_id_from_metadata(metadata) { - Some(m) => Ok(m), - None => Err(Status::new( - Code::Internal, - "Network ID was not found in metadata", - )), - } - } - - // parse network id from gateway request metadata from intercepted information from JWT token - fn get_network_id_from_metadata(metadata: &MetadataMap) -> Option { - if let Some(ascii_value) = metadata.get("gateway_network_id") { - if let Ok(slice) = ascii_value.clone().to_str() { - if let Ok(id) = slice.parse::() { - return Some(id); - } - } - } - None - } - - // extract gateway hostname from request headers - fn get_gateway_hostname(metadata: &MetadataMap) -> Result { - match metadata.get("hostname") { - Some(ascii_value) => { - let hostname = ascii_value.to_str().map_err(|_| { - Status::new( - Code::Internal, - "Failed to parse gateway hostname from request metadata", - ) - })?; - Ok(hostname.into()) - } - None => Err(Status::new( - Code::Internal, - "Gateway hostname not found in request metadata", - )), - } - } - - /// Utility function extracting metadata fields during gRPC communication. - fn extract_metadata(metadata: &MetadataMap) -> Result { - let (version, _info) = version_info_from_metadata(metadata); - Ok(GatewayMetadata { - network_id: Self::get_network_id(metadata)?, - hostname: Self::get_gateway_hostname(metadata)?, - version, - }) - } -} - fn gen_config( network: &WireguardNetwork, peers: Vec, @@ -202,166 +119,7 @@ impl WireguardPeerStats { } } -/// Process received Gateway events -/// -/// Main gRPC server uses a shared channel for broadcasting all Gateway events, -/// so the handler must determine if an event is relevant for the network being serviced. -async fn handle_events( - mut current_network: WireguardNetwork, - // gateway_hostname: String, - tx: UnboundedSender, - mut events_rx: BroadcastReceiver, -) { - info!("Starting update stream network {current_network}"); - // while let Some(event) = events_rx.recv().await { - // debug!("Received networking state update event: {event:?}"); - // let (update_type, update) = match event { - // GatewayEvent::NetworkCreated(network, _fixme) => { - // if network.id != current_network.id { - // continue; - // } - // ( - // UpdateType::Create, - // update::Update::Network(Configuration { - // name: network.name.clone(), - // prvkey: network.prvkey.clone(), - // addresses: network.address.to_string(), - // port: network.port as u32, - // peers: Vec::new(), - // }), - // ) - // } - // GatewayEvent::NetworkModified(network, peers, _fixme) => { - // if network.id != current_network.id { - // continue; - // } - // // update stored network data - // current_network = network.clone(); - // ( - // UpdateType::Modify, - // update::Update::Network(Configuration { - // name: network.name, - // prvkey: network.prvkey, - // addresses: network.address.to_string(), - // port: network.port as u32, - // peers, - // }), - // ) - // } - // GatewayEvent::NetworkDeleted(network_id, network_name) => { - // if network_id != current_network.id { - // continue; - // } - // ( - // UpdateType::Delete, - // update::Update::Network(Configuration { - // name: network_name.to_string(), - // prvkey: String::new(), - // addresses: Vec::new(), - // port: 0, - // peers: Vec::new(), - // firewall_config: None, - // }), - // ) - // } - // GatewayEvent::DeviceCreated(device) => { - // // check if a peer has to be added in the current network - // match device - // .network_info - // .iter() - // .find(|info| info.network_id == current_network.id) - // { - // Some(network_info) => { - // if current_network.mfa_enabled && !network_info.is_authorized { - // debug!( - // "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", - // device.device.name, current_network.name - // ); - // continue; - // }; - // let peer = Peer { - // pubkey: device.device.wireguard_pubkey, - // allowed_ips: vec![network_info.device_wireguard_ip.to_string()], - // preshared_key: network_info.preshared_key.clone(), - // keepalive_interval: Some(current_network.keepalive_interval as u32), - // }; - // (UpdateType::Create, update::Update::Peer(peer)) - // } - // None => continue, - // } - // } - // GatewayEvent::DeviceModified(device) => { - // // check if a peer has to be updated in the current network - // match device - // .network_info - // .iter() - // .find(|info| info.network_id == current_network.id) - // { - // Some(network_info) => { - // if current_network.mfa_enabled && !network_info.is_authorized { - // debug!( - // "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", - // device.device.name, current_network.name - // ); - // continue; - // }; - // let peer = Peer { - // pubkey: device.device.wireguard_pubkey, - // allowed_ips: vec![network_info.device_wireguard_ip.to_string()], - // preshared_key: network_info.preshared_key.clone(), - // keepalive_interval: Some(current_network.keepalive_interval as u32), - // }; - // (UpdateType::Modify, update::Update::Peer(peer)) - // } - // None => continue, - // } - // } - // GatewayEvent::DeviceDeleted(device) => { - // // check if a peer has to be updated in the current network - // match device - // .network_info - // .iter() - // .find(|info| info.network_id == current_network.id) - // { - // Some(_) => ( - // UpdateType::Delete, - // update::Update::Peer(Peer { - // pubkey: device.device.wireguard_pubkey, - // allowed_ips: Vec::new(), - // preshared_key: None, - // keepalive_interval: None, - // }), - // ), - // None => continue, - // } - // } - // GatewayEvent::FirewallConfigChanged(_fixme, _) => (), - // GatewayEvent::FirewallDisabled(_id) => (), - // }; - - // let req = CoreResponse { - // id: 0, - // payload: Some(core_response::Payload::Update(Update { - // update_type: update_type as i32, - // update: Some(update), - // })), - // }; - // if let Err(err) = tx.send(req) { - // error!( - // "Failed to send network update, network {current_network}, update type: {}, error: \ - // {err}", - // update_type.as_str_name() - // ); - // break; - // } - // debug!( - // "Network update sent for network {current_network}, update type: {}", - // update_type.as_str_name() - // ); - // } -} - -/// Helper struct for handling gateway events +/// Helper struct for handling gateway events. struct GatewayUpdatesHandler { network_id: Id, network: WireguardNetwork, @@ -561,7 +319,8 @@ impl GatewayUpdatesHandler { })), }) { let msg = format!( - "Failed to send network update, network {network}, update type: {update_type} ({}), error: {err}", + "Failed to send network update, network {network}, update type: {update_type} \ + ({}), error: {err}", if update_type == 0 { "CREATE" } else { "MODIFY" }, ); error!(msg); @@ -699,26 +458,11 @@ impl GatewayUpdatesHandler { } } -// pub struct GatewayUpdatesStream { -// task_handle: JoinHandle<()>, -// rx: Receiver>, -// network_id: Id, -// gateway_hostname: String, -// gateway_state: Arc>, -// pool: PgPool, -// } - // #[tonic::async_trait] // impl gateway_service_server::GatewayService for GatewayServer { // type UpdatesStream = GatewayUpdatesStream; // // async fn updates(&self, request: Request<()>) -> Result, Status> { -// let GatewayMetadata { -// network_id, -// hostname, -// .. -// // info, -// } = Self::extract_metadata(request.metadata())?; // // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. // // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, // // version = version.to_string(), info); From 640bae9a0aea1e11395f0a29fb8c84eeefd7f115 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 3 Dec 2025 10:33:53 +0100 Subject: [PATCH 13/17] Resurrect gateway test --- Cargo.lock | 31 ++++++++----- crates/defguard_core/src/grpc/gateway/mod.rs | 8 ++-- .../defguard_core/src/grpc/gateway/state.rs | 14 ------ .../defguard_core/src/grpc/gateway/tests.rs | 44 +++++++++++++++---- .../tests/integration/grpc/common/mod.rs | 2 +- 5 files changed, 62 insertions(+), 37 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 009d6fab3..96317a4a6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -765,6 +765,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "convert_case" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "633458d4ef8c78b72454de2d54fd6ab2e60f9e02be22f3c6104cdc8a4e0fceb9" +dependencies = [ + "unicode-segmentation", +] + [[package]] name = "cookie" version = "0.18.1" @@ -1379,7 +1388,7 @@ version = "0.99.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ - "convert_case", + "convert_case 0.4.0", "proc-macro2", "quote", "rustc_version", @@ -1388,21 +1397,23 @@ dependencies = [ [[package]] name = "derive_more" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678" +checksum = "10b768e943bed7bf2cab53df09f4bc34bfd217cdb57d971e769874c9a6710618" dependencies = [ "derive_more-impl", ] [[package]] name = "derive_more-impl" -version = "2.0.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" +checksum = "6d286bfdaf75e988b4a78e013ecd79c581e06399ab53fbacd2d916c2f904f30b" dependencies = [ + "convert_case 0.10.0", "proc-macro2", "quote", + "rustc_version", "syn", "unicode-xid", ] @@ -2705,9 +2716,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.177" +version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libgit2-sys" @@ -2798,9 +2809,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.28" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "lru-slab" @@ -3672,7 +3683,7 @@ dependencies = [ "curve25519-dalek", "cx448", "derive_builder", - "derive_more 2.0.1", + "derive_more 2.1.0", "des", "digest", "dsa", diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index 7be29eac1..afcdc81fb 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -32,17 +32,17 @@ pub mod client_state; pub(crate) mod handler; pub mod map; pub(crate) mod state; -// #[cfg(test)] -// mod tests; +#[cfg(test)] +mod tests; -const PEER_DISCONNECT_INTERVAL: u64 = 60; +#[cfg(test)] pub(super) static TONIC_SOCKET: &str = "tonic.sock"; /// Sends given `GatewayEvent` to be handled by gateway GRPC server /// /// If you want to use it inside the API context, use [`crate::AppState::send_wireguard_event`] instead pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { - debug!("Sending the following WireGuard event to the gateway: {event:?}"); + debug!("Sending the following WireGuard event to Defguard Gateway: {event:?}"); if let Err(err) = wg_tx.send(event) { error!("Error sending WireGuard event {err}"); } diff --git a/crates/defguard_core/src/grpc/gateway/state.rs b/crates/defguard_core/src/grpc/gateway/state.rs index 0f9b10f62..a628810f3 100644 --- a/crates/defguard_core/src/grpc/gateway/state.rs +++ b/crates/defguard_core/src/grpc/gateway/state.rs @@ -13,7 +13,6 @@ use utoipa::ToSchema; use uuid::Uuid; use crate::{ - // db::models::gateway::Gateway, grpc::MIN_GATEWAY_VERSION, handlers::mail::{send_gateway_disconnected_email, send_gateway_reconnected_email}, }; @@ -37,19 +36,6 @@ pub struct GatewayState { } impl GatewayState { - // pub(crate) fn from_gateway(gateway: &Gateway, network_name: &str) -> Self { - // Self { - // id: gateway.id, - // connected: gateway.is_connected(), - // network_id: gateway.network_id, - // network_name: network_name.to_owned(), - // name: None, // TODO: remove - // hostname: gateway.hostname.clone().unwrap_or_default(), - // connected_at: gateway.connected_at, - // disconnected_at: gateway.disconnected_at, - // } - // } - #[must_use] pub fn new>( network_id: Id, diff --git a/crates/defguard_core/src/grpc/gateway/tests.rs b/crates/defguard_core/src/grpc/gateway/tests.rs index f79b77dba..b00b1f004 100644 --- a/crates/defguard_core/src/grpc/gateway/tests.rs +++ b/crates/defguard_core/src/grpc/gateway/tests.rs @@ -1,18 +1,31 @@ use std::{ io, net::{IpAddr, Ipv4Addr}, + sync::{Arc, Mutex}, }; +use defguard_common::db::setup_pool; +use defguard_mail::Mail; +use defguard_proto::gateway::{CoreRequest, CoreResponse, gateway_server}; use ipnetwork::IpNetwork; +use sqlx::postgres::{PgConnectOptions, PgPoolOptions}; use tokio::{ net::UnixListener, sync::{broadcast, mpsc::unbounded_channel}, }; -use tokio_stream::wrappers::UnixListenerStream; +use tokio_stream::wrappers::{UnboundedReceiverStream, UnixListenerStream}; use tonic::{Request, Response, Status, Streaming, transport::Server}; -use super::*; +use super::{TONIC_SOCKET, handler::GatewayHandler}; +use crate::{ + db::models::{ + gateway::Gateway, + wireguard::{GatewayEvent, LocationMfaMode, ServiceLocationMode, WireguardNetwork}, + }, + grpc::{ClientMap, GrpcEvent}, +}; +// TODO: move to "gateway" repo. struct FakeGateway; #[tonic::async_trait] @@ -23,7 +36,7 @@ impl gateway_server::Gateway for FakeGateway { &self, request: Request>, ) -> Result, Status> { - let (_tx, rx) = mpsc::unbounded_channel(); + let (_tx, rx) = unbounded_channel(); let mut stream = request.into_inner(); tokio::spawn(async move { loop { @@ -55,17 +68,21 @@ async fn fake_gateway() -> Result<(), io::Error> { } #[sqlx::test] -async fn test_gateway(pool: PgPool) { +async fn test_gateway(_: PgPoolOptions, options: PgConnectOptions) { + let pool = setup_pool(options).await; let network = WireguardNetwork::new( "TestNet".to_string(), - IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap(), + vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 1)), 24).unwrap()], 50051, "0.0.0.0".to_string(), None, vec![IpNetwork::new(IpAddr::V4(Ipv4Addr::new(10, 1, 1, 0)), 24).unwrap()], - false, 0, 0, + false, + false, + LocationMfaMode::default(), + ServiceLocationMode::default(), ) .save(&pool) .await @@ -74,10 +91,21 @@ async fn test_gateway(pool: PgPool) { .save(&pool) .await .unwrap(); - let (events_tx, _events_rx) = broadcast::channel::(16); + let client_state = Arc::new(Mutex::new(ClientMap::new())); + let (events_tx, _events_rx) = broadcast::channel::(16); let (mail_tx, _mail_rx) = unbounded_channel::(); + let (grpc_event_tx, _grpc_event_rx) = unbounded_channel::(); - let mut gateway_handler = GatewayHandler::new(gateway, None, pool, events_tx, mail_tx).unwrap(); + let mut gateway_handler = GatewayHandler::new( + gateway, + None, + pool, + client_state, + events_tx, + mail_tx, + grpc_event_tx, + ) + .unwrap(); let handle = tokio::spawn(async move { gateway_handler.handle_connection().await; }); diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index b919afcd4..82525af4a 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -28,7 +28,7 @@ use tower::service_fn; use crate::common::{init_config, initialize_users}; -// pub mod mock_gateway; +pub mod mock_gateway; pub struct TestGrpcServer { grpc_server_task_handle: JoinHandle<()>, From ff62044a209a71d290f1af48a03cec23a4ed3aca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Wed, 3 Dec 2025 14:03:58 +0100 Subject: [PATCH 14/17] Do not extract metadata --- .../defguard_core/src/grpc/gateway/handler.rs | 59 ++++++++----------- crates/defguard_core/src/version.rs | 8 +-- 2 files changed, 27 insertions(+), 40 deletions(-) diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index bc3213522..eb2c81b58 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -90,17 +90,7 @@ impl GatewayHandler { }) } - fn get_network_id(metadata: &MetadataMap) -> Result { - match Self::get_network_id_from_metadata(metadata) { - Some(m) => Ok(m), - None => Err(Status::new( - Code::Internal, - "Network ID was not found in metadata", - )), - } - } - - // parse network id from gateway request metadata from intercepted information from JWT token + // Parse network ID from Gateway request metadata from intercepted information from JWT token. fn get_network_id_from_metadata(metadata: &MetadataMap) -> Option { if let Some(ascii_value) = metadata.get("gateway_network_id") { if let Ok(slice) = ascii_value.clone().to_str() { @@ -112,30 +102,28 @@ impl GatewayHandler { None } - // extract gateway hostname from request headers - fn get_gateway_hostname(metadata: &MetadataMap) -> Result { + // Extract Gateway hostname from request headers. + fn get_gateway_hostname(metadata: &MetadataMap) -> Option { match metadata.get("hostname") { Some(ascii_value) => { - let hostname = ascii_value.to_str().map_err(|_| { - Status::new( - Code::Internal, - "Failed to parse gateway hostname from request metadata", - ) - })?; - Ok(hostname.into()) + let Ok(hostname) = ascii_value.to_str() else { + error!("Failed to parse Gateway hostname from request metadata"); + return None; + }; + Some(hostname.into()) + } + None => { + error!("Gateway hostname not found in request metadata"); + None } - None => Err(Status::new( - Code::Internal, - "Gateway hostname not found in request metadata", - )), } } /// Utility function extracting metadata fields during gRPC communication. - fn extract_metadata(metadata: &MetadataMap) -> Result { + fn extract_metadata(metadata: &MetadataMap) -> Option { let (version, _info) = version_info_from_metadata(metadata); - Ok(GatewayMetadata { - network_id: Self::get_network_id(metadata)?, + Some(GatewayMetadata { + network_id: 0, // FIXME: not needed; was Self::get_network_id_from_metadata(metadata)?, hostname: Self::get_gateway_hostname(metadata)?, version, }) @@ -364,21 +352,20 @@ impl GatewayHandler { let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { Ok(response) => response, Err(err) => { - error!("Failed to connect to gateway {uri}, retrying: {err}"); + error!("Failed to connect to Gateway {uri}, retrying: {err}"); sleep(TEN_SECS).await; continue; } }; info!("Connected to Defguard Gateway {uri}"); - let Ok(GatewayMetadata { - network_id, - hostname, - .. - // info, - }) = Self::extract_metadata(response.metadata()) else { - continue; - }; + // Metadata isn't needed in reversed communication. TODO: remove, but only check version. + // let Some(GatewayMetadata { + // hostname, + // }) = Self::extract_metadata(response.metadata()) else { + // error!("Failed to extract metadata"); + // continue; + // }; let mut resp_stream = response.into_inner(); let mut config_sent = false; diff --git a/crates/defguard_core/src/version.rs b/crates/defguard_core/src/version.rs index 849c23233..16043976f 100644 --- a/crates/defguard_core/src/version.rs +++ b/crates/defguard_core/src/version.rs @@ -10,7 +10,7 @@ use serde::Serialize; use tonic::{Status, service::Interceptor}; const MIN_PROXY_VERSION: Version = Version::new(1, 6, 0); -pub const MIN_GATEWAY_VERSION: Version = Version::new(1, 5, 0); +pub const MIN_GATEWAY_VERSION: Version = Version::new(1, 6, 0); static OUTDATED_COMPONENT_LIFETIME: TimeDelta = TimeDelta::hours(1); /// Checks if Defguard Proxy version meets minimum version requirements. @@ -110,7 +110,7 @@ impl Interceptor for GatewayVersionInterceptor { } } -#[derive(Debug, Default, Clone, Serialize)] +#[derive(Default, Clone, Serialize)] pub struct IncompatibleComponents { pub gateways: HashSet, pub proxy: Option, @@ -204,7 +204,7 @@ impl IncompatibleComponents { } } -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Serialize)] pub struct IncompatibleGatewayData { pub version: Option, pub hostname: Option, @@ -261,7 +261,7 @@ impl IncompatibleGatewayData { } } -#[derive(Clone, Debug, Serialize)] +#[derive(Clone, Serialize)] pub struct IncompatibleProxyData { pub version: Option, created: NaiveDateTime, From 41681df5fcc73637765730d758fe316566d60d21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 9 Dec 2025 10:47:59 +0100 Subject: [PATCH 15/17] Add version to gateway gRPC --- Cargo.lock | 48 +++++++++---------- .../defguard_core/src/grpc/gateway/handler.rs | 9 ++-- 2 files changed, 30 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 96317a4a6..1dda62d88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -404,9 +404,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "base64ct" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" +checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a" [[package]] name = "base64urlsafedata" @@ -574,9 +574,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.48" +version = "1.2.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c481bdbf0ed3b892f6f806287d72acd515b352a4ec27a208489b8c1bc839633a" +checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" dependencies = [ "find-msvc-tools", "jobserver", @@ -1954,9 +1954,9 @@ dependencies = [ [[package]] name = "git2" -version = "0.20.2" +version = "0.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2deb07a133b1520dc1a5690e9bd08950108873d7ed5de38dcc74d3b5ebffa110" +checksum = "3e2b37e2f62729cdada11f0e6b3b6fe383c69c29fc619e391223e12856af308c" dependencies = [ "bitflags 2.10.0", "libc", @@ -2300,9 +2300,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.18" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e9a2a24dc5c6821e71a7030e1e14b7b632acac55c40e9d2e082c621261bb56" +checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" dependencies = [ "base64 0.22.1", "bytes", @@ -2722,9 +2722,9 @@ checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libgit2-sys" -version = "0.18.2+1.9.1" +version = "0.18.3+1.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c42fe03df2bd3c53a3a9c7317ad91d80c81cd1fb0caec8d7cc4cd2bfa10c222" +checksum = "c9b3acc4b91781bb0b3386669d325163746af5f6e4f73e6d2d630e09a35f3487" dependencies = [ "cc", "libc", @@ -2761,9 +2761,9 @@ dependencies = [ [[package]] name = "libz-rs-sys" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "840db8cf39d9ec4dd794376f38acc40d0fc65eec2a8f484f7fd375b84602becd" +checksum = "8b484ba8d4f775eeca644c452a56650e544bf7e617f1d170fe7298122ead5222" dependencies = [ "zlib-rs", ] @@ -2933,9 +2933,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "wasi", @@ -4222,9 +4222,9 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "reqwest" -version = "0.12.24" +version = "0.12.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +checksum = "b6eff9328d40131d43bd911d42d79eb6a47312002a4daefc9e37f17e74a7701a" dependencies = [ "base64 0.22.1", "bytes", @@ -4882,9 +4882,9 @@ dependencies = [ [[package]] name = "simd-adler32" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" [[package]] name = "simple_asn1" @@ -5630,9 +5630,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.23.7" +version = "0.23.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6485ef6d0d9b5d0ec17244ff7eb05310113c3f316f2d14200d4de56b3cb98f8d" +checksum = "5d7cbc3b4b49633d57a0509303158ca50de80ae32c265093b24c414705807832" dependencies = [ "indexmap 2.12.1", "toml_datetime", @@ -5766,9 +5766,9 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "bitflags 2.10.0", "bytes", @@ -6914,9 +6914,9 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f06ae92f42f5e5c42443fd094f245eb656abf56dd7cce9b8b263236565e00f2" +checksum = "36134c44663532e6519d7a6dfdbbe06f6f8192bde8ae9ed076e9b213f0e31df7" [[package]] name = "zopfli" diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index eb2c81b58..8d6bb6b8a 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -8,10 +8,10 @@ use std::{ }; use chrono::{TimeDelta, Utc}; -use defguard_common::{auth::claims::Claims, db::Id}; +use defguard_common::{VERSION, auth::claims::Claims, db::Id}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; -use defguard_version::version_info_from_metadata; +use defguard_version::{client::ClientVersionInterceptor, version_info_from_metadata}; use semver::Version; use sqlx::PgPool; use tokio::{ @@ -347,7 +347,10 @@ impl GatewayHandler { )); debug!("Connecting to Gateway {uri}"); - let mut client = gateway_client::GatewayClient::new(channel); + let interceptor = ClientVersionInterceptor::new( + Version::parse(VERSION).expect("failed to parse self version"), + ); + let mut client = gateway_client::GatewayClient::with_interceptor(channel, interceptor); let (tx, rx) = mpsc::unbounded_channel(); let response = match client.bidi(UnboundedReceiverStream::new(rx)).await { Ok(response) => response, From 3b1f0973ca6ec80cb26c11f48ec2af72efaf38d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 9 Dec 2025 13:07:41 +0100 Subject: [PATCH 16/17] Cleanup --- crates/defguard_core/src/grpc/client_mfa.rs | 4 +- .../defguard_core/src/grpc/gateway/handler.rs | 67 +++-------- crates/defguard_core/src/grpc/gateway/mod.rs | 106 ++---------------- .../tests/integration/grpc/common/mod.rs | 5 - 4 files changed, 28 insertions(+), 154 deletions(-) diff --git a/crates/defguard_core/src/grpc/client_mfa.rs b/crates/defguard_core/src/grpc/client_mfa.rs index f688a41a4..90abe62f4 100644 --- a/crates/defguard_core/src/grpc/client_mfa.rs +++ b/crates/defguard_core/src/grpc/client_mfa.rs @@ -484,7 +484,7 @@ impl ClientMfaServer { } MfaMethod::Totp => { let code = if let Some(code) = request.code { - code.to_string() + code.clone() } else { error!("TOTP code not provided in request"); self.emit_event(BidiStreamEvent { @@ -518,7 +518,7 @@ impl ClientMfaServer { } MfaMethod::Email => { let code = if let Some(code) = request.code { - code.to_string() + code.clone() } else { error!("Email MFA code not provided in request"); self.emit_event(BidiStreamEvent { diff --git a/crates/defguard_core/src/grpc/gateway/handler.rs b/crates/defguard_core/src/grpc/gateway/handler.rs index 8d6bb6b8a..896838631 100644 --- a/crates/defguard_core/src/grpc/gateway/handler.rs +++ b/crates/defguard_core/src/grpc/gateway/handler.rs @@ -11,7 +11,7 @@ use chrono::{TimeDelta, Utc}; use defguard_common::{VERSION, auth::claims::Claims, db::Id}; use defguard_mail::Mail; use defguard_proto::gateway::{CoreResponse, core_request, core_response, gateway_client}; -use defguard_version::{client::ClientVersionInterceptor, version_info_from_metadata}; +use defguard_version::client::ClientVersionInterceptor; use semver::Version; use sqlx::PgPool; use tokio::{ @@ -50,14 +50,6 @@ pub(crate) struct GatewayHandler { grpc_event_tx: UnboundedSender, } -/// Utility struct encapsulating commonly extracted metadata fields during gRPC communication. -struct GatewayMetadata { - network_id: Id, - hostname: String, - version: Version, - // info: String, -} - impl GatewayHandler { pub(crate) fn new( gateway: Gateway, @@ -68,7 +60,7 @@ impl GatewayHandler { mail_tx: UnboundedSender, grpc_event_tx: UnboundedSender, ) -> Result { - let endpoint = Endpoint::from_shared(gateway.url.to_string())? + let endpoint = Endpoint::from_shared(gateway.url.clone())? .http2_keep_alive_interval(TEN_SECS) .tcp_keepalive(Some(TEN_SECS)) .keep_alive_while_idle(true); @@ -104,31 +96,18 @@ impl GatewayHandler { // Extract Gateway hostname from request headers. fn get_gateway_hostname(metadata: &MetadataMap) -> Option { - match metadata.get("hostname") { - Some(ascii_value) => { - let Ok(hostname) = ascii_value.to_str() else { - error!("Failed to parse Gateway hostname from request metadata"); - return None; - }; - Some(hostname.into()) - } - None => { - error!("Gateway hostname not found in request metadata"); - None - } + if let Some(ascii_value) = metadata.get("hostname") { + let Ok(hostname) = ascii_value.to_str() else { + error!("Failed to parse Gateway hostname from request metadata"); + return None; + }; + Some(hostname.into()) + } else { + error!("Gateway hostname not found in request metadata"); + None } } - /// Utility function extracting metadata fields during gRPC communication. - fn extract_metadata(metadata: &MetadataMap) -> Option { - let (version, _info) = version_info_from_metadata(metadata); - Some(GatewayMetadata { - network_id: 0, // FIXME: not needed; was Self::get_network_id_from_metadata(metadata)?, - hostname: Self::get_gateway_hostname(metadata)?, - version, - }) - } - /// Send network and VPN configuration to Gateway. async fn send_configuration( &self, @@ -179,7 +158,7 @@ impl GatewayHandler { let maybe_firewall_config = network - .try_get_firewall_config(&mut *conn) + .try_get_firewall_config(&mut conn) .await .map_err(|err| { error!("Failed to generate firewall config for network {network_id}: {err}"); @@ -255,7 +234,7 @@ impl GatewayHandler { "{} disconnected. Email notification not sent.", self.gateway ); - }; + } } /// Helper method to fetch `Device` info from DB by pubkey and return appropriate errors @@ -360,15 +339,7 @@ impl GatewayHandler { continue; } }; - info!("Connected to Defguard Gateway {uri}"); - // Metadata isn't needed in reversed communication. TODO: remove, but only check version. - // let Some(GatewayMetadata { - // hostname, - // }) = Self::extract_metadata(response.metadata()) else { - // error!("Failed to extract metadata"); - // continue; - // }; let mut resp_stream = response.into_inner(); let mut config_sent = false; @@ -431,20 +402,19 @@ impl GatewayHandler { .gateway .touch_connected(&self.pool, config_request.hostname) .await; - let mut guh = super::GatewayUpdatesHandler::new( + let mut updates_handler = super::GatewayUpdatesHandler::new( self.gateway.network_id, network, self.gateway .hostname - .as_ref() - .cloned() + .clone() .unwrap_or_default() .clone(), self.events_tx.subscribe(), tx.clone(), ); tokio::spawn(async move { - guh.run().await; + updates_handler.run().await; }); } Err(err) => { @@ -548,8 +518,7 @@ impl GatewayHandler { &self .gateway .hostname - .as_ref() - .cloned() + .clone() .unwrap_or_default(), &public_key, &device, @@ -617,7 +586,7 @@ impl GatewayHandler { debug!("WireGuard peer stats: {stats:?}"); } None => (), - }; + } } Err(err) => { error!("Disconnected from Gateway at {uri}, error: {err}"); diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index afcdc81fb..b1e48deca 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,31 +1,23 @@ -use std::{ - net::IpAddr, - sync::{Arc, Mutex}, -}; +use std::net::IpAddr; use chrono::{DateTime, Utc}; -use client_state::ClientMap; use defguard_common::db::{Id, NoId}; -use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; -use sqlx::PgPool; -use thiserror::Error; use tokio::sync::{ broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::{UnboundedSender, error::SendError}, + mpsc::UnboundedSender, }; use tonic::{Code, Status}; -use self::map::GatewayMap; use crate::{ db::{ GatewayEvent, models::{wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, }, - events::{GrpcEvent, GrpcRequestContext}, + events::GrpcRequestContext, }; pub mod client_state; @@ -52,36 +44,12 @@ pub fn send_wireguard_event(event: GatewayEvent, wg_tx: &Sender) { /// /// If you want to use it inside the API context, use [`crate::AppState::send_multiple_wireguard_events`] instead pub fn send_multiple_wireguard_events(events: Vec, wg_tx: &Sender) { - debug!("Sending {} wireguard events", events.len()); + debug!("Sending {} WireGuard events", events.len()); for event in events { send_wireguard_event(event, wg_tx); } } -#[allow(clippy::large_enum_variant)] -#[derive(Debug, Error)] -pub enum GatewayServerError { - #[error("Failed to acquire lock on VPN client state map")] - ClientStateMutexError, - #[error("gRPC event channel error: {0}")] - GrpcEventChannelError(#[from] SendError), -} - -impl From for Status { - fn from(value: GatewayServerError) -> Self { - Self::new(Code::Internal, value.to_string()) - } -} - -pub struct GatewayServer { - pool: PgPool, - gateway_state: Arc>, - client_state: Arc>, - wireguard_tx: Sender, - mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, -} - fn gen_config( network: &WireguardNetwork, peers: Vec, @@ -199,7 +167,8 @@ impl GatewayUpdatesHandler { Some(network_info) => { if self.network.mfa_enabled() && !network_info.is_authorized { debug!( - "Created WireGuard device {} is not authorized to connect to MFA enabled location {}", + "Created WireGuard device {} is not authorized to connect to \ + MFA enabled location {}", device.device.name, self.network.name ); continue; @@ -234,7 +203,8 @@ impl GatewayUpdatesHandler { Some(network_info) => { if self.network.mfa_enabled() && !network_info.is_authorized { debug!( - "Modified WireGuard device {} is not authorized to connect to MFA enabled location {}", + "Modified WireGuard device {} is not authorized to connect to \ + MFA enabled location {}", device.device.name, self.network.name ); continue; @@ -457,63 +427,3 @@ impl GatewayUpdatesHandler { Ok(()) } } - -// #[tonic::async_trait] -// impl gateway_service_server::GatewayService for GatewayServer { -// type UpdatesStream = GatewayUpdatesStream; -// -// async fn updates(&self, request: Request<()>) -> Result, Status> { -// // FIXME: tracing causes looping messages, like `INFO gateway_config:gateway_stats:...`. -// // let span = tracing::info_span!("gateway_updates", component = %DefguardComponent::Gateway, -// // version = version.to_string(), info); -// // let _guard = span.enter(); - -// let Some(network) = WireguardNetwork::find_by_id(&self.pool, network_id) -// .await -// .map_err(|_| { -// error!("Failed to fetch network {network_id} from the database"); -// Status::new( -// Code::Internal, -// format!("Failed to retrieve network {network_id} from the database"), -// ) -// })? -// else { -// return Err(Status::new( -// Code::Internal, -// format!("Network with id {network_id} not found"), -// )); -// }; - -// info!("New client connected to updates stream: {hostname}, network {network}",); - -// let (tx, rx) = mpsc::channel(4); -// let events_rx = self.wireguard_tx.subscribe(); -// let mut state = self.gateway_state.lock().unwrap(); -// state -// .connect_gateway(network_id, &hostname, &self.pool) -// .map_err(|err| { -// error!("Failed to connect gateway on network {network_id}: {err}"); -// Status::new( -// Code::Internal, -// format!("Failed to connect gateway on network {network_id}"), -// ) -// })?; - -// // clone here before moving into a closure -// let gateway_hostname = hostname.clone(); -// let handle = tokio::spawn(async move { -// let mut update_handler = -// GatewayUpdatesHandler::new(network_id, network, gateway_hostname, events_rx, tx); -// update_handler.run().await; -// }); - -// Ok(Response::new(GatewayUpdatesStream::new( -// handle, -// rx, -// network_id, -// hostname, -// Arc::clone(&self.gateway_state), -// self.pool.clone(), -// ))) -// } -// } diff --git a/crates/defguard_core/tests/integration/grpc/common/mod.rs b/crates/defguard_core/tests/integration/grpc/common/mod.rs index 82525af4a..d4ca1d0b1 100644 --- a/crates/defguard_core/tests/integration/grpc/common/mod.rs +++ b/crates/defguard_core/tests/integration/grpc/common/mod.rs @@ -156,13 +156,8 @@ pub(crate) async fn make_grpc_test_server(pool: &PgPool) -> TestGrpcServer { server, pool.clone(), worker_state, - // gateway_state.clone(), - // client_state.clone(), - // wg_tx.clone(), mail_tx, failed_logins, - // grpc_event_tx, - // Default::default(), ) .await .unwrap(); From 9e36670758fbce5f8807cd460e3bb23e0e20aad9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adam=20Ciarcin=CC=81ski?= Date: Tue, 16 Dec 2025 14:45:16 +0100 Subject: [PATCH 17/17] Re-organise --- Cargo.lock | 44 +++--- crates/defguard/src/main.rs | 4 +- .../defguard_core/src/db/models/wireguard.rs | 4 +- crates/defguard_core/src/grpc/gateway/mod.rs | 126 +++++++++++++++++- crates/defguard_core/src/grpc/mod.rs | 110 +-------------- .../defguard_core/src/handlers/wireguard.rs | 90 +++++-------- crates/defguard_version/src/lib.rs | 2 +- 7 files changed, 181 insertions(+), 199 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1dda62d88..a401d595c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -410,9 +410,9 @@ checksum = "0e050f626429857a27ddccb31e0aca21356bfa709c04041aefddac081a8f068a" [[package]] name = "base64urlsafedata" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "215ee31f8a88f588c349ce2d20108b2ed96089b96b9c2b03775dc35dd72938e8" +checksum = "42f7f6be94fa637132933fd0a68b9140bcb60e3d46164cb68e82a2bb8d102b3a" dependencies = [ "base64 0.21.7", "pastey", @@ -2396,9 +2396,9 @@ checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a" [[package]] name = "icu_properties" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99" +checksum = "020bfc02fe870ec3a66d93e677ccca0562506e5872c650f893269e08615d74ec" dependencies = [ "icu_collections", "icu_locale_core", @@ -2410,9 +2410,9 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899" +checksum = "616c294cf8d725c6afcd8f55abc17c56464ef6211f9ed59cccffe534129c77af" [[package]] name = "icu_provider" @@ -2761,9 +2761,9 @@ dependencies = [ [[package]] name = "libz-rs-sys" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b484ba8d4f775eeca644c452a56650e544bf7e617f1d170fe7298122ead5222" +checksum = "15413ef615ad868d4d65dce091cb233b229419c7c0c4bcaa746c0901c49ff39c" dependencies = [ "zlib-rs", ] @@ -6285,9 +6285,9 @@ dependencies = [ [[package]] name = "webauthn-attestation-ca" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f77a2892ec44032e6c48dad9aad1b05fada09c346ada11d8d32db119b4b4f205" +checksum = "fafcf13f7dc1fb292ed4aea22cdd3757c285d7559e9748950ee390249da4da6b" dependencies = [ "base64urlsafedata", "openssl", @@ -6299,9 +6299,9 @@ dependencies = [ [[package]] name = "webauthn-authenticator-rs" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45f8fe3811c8d6c6830d263452670a608fd4dcdfc481349bd4d1e6a46d6c7a0f" +checksum = "78b41ed08aba475a969094226ae0691a286686210ae497bb2c5d0ed722d8d526" dependencies = [ "async-stream", "async-trait", @@ -6332,9 +6332,9 @@ dependencies = [ [[package]] name = "webauthn-rs" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb7c3a2f9c8bddd524e47bbd427bcf3a28aa074de55d74470b42a91a41937b8e" +checksum = "1b24d082d3360258fefb6ffe56123beef7d6868c765c779f97b7a2fcf06727f8" dependencies = [ "base64urlsafedata", "serde", @@ -6346,9 +6346,9 @@ dependencies = [ [[package]] name = "webauthn-rs-core" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19f1d80f3146382529fe70a3ab5d0feb2413a015204ed7843f9377cd39357fc4" +checksum = "15784340a24c170ce60567282fb956a0938742dbfbf9eff5df793a686a009b8b" dependencies = [ "base64 0.21.7", "base64urlsafedata", @@ -6357,8 +6357,8 @@ dependencies = [ "nom 7.1.3", "openssl", "openssl-sys", - "rand 0.8.5", - "rand_chacha 0.3.1", + "rand 0.9.2", + "rand_chacha 0.9.0", "serde", "serde_cbor_2 0.13.0", "serde_json", @@ -6373,9 +6373,9 @@ dependencies = [ [[package]] name = "webauthn-rs-proto" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e786894f89facb9aaf1c5f6559670236723c98382e045521c76f3d5ca5047bd" +checksum = "16a1fb2580ce73baa42d3011a24de2ceab0d428de1879ece06e02e8c416e497c" dependencies = [ "base64 0.21.7", "base64urlsafedata", @@ -6914,9 +6914,9 @@ dependencies = [ [[package]] name = "zlib-rs" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36134c44663532e6519d7a6dfdbbe06f6f8192bde8ae9ed076e9b213f0e31df7" +checksum = "51f936044d677be1a1168fae1d03b583a285a5dd9d8cbf7b24c23aa1fc775235" [[package]] name = "zopfli" diff --git a/crates/defguard/src/main.rs b/crates/defguard/src/main.rs index 11f25325a..070c652ba 100644 --- a/crates/defguard/src/main.rs +++ b/crates/defguard/src/main.rs @@ -23,8 +23,8 @@ use defguard_core::{ events::{ApiEvent, BidiStreamEvent, GrpcEvent, InternalEvent}, grpc::{ WorkerState, - gateway::{client_state::ClientMap, map::GatewayMap}, - run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server, + gateway::{client_state::ClientMap, map::GatewayMap, run_grpc_gateway_stream}, + run_grpc_bidi_stream, run_grpc_server, }, init_dev_env, init_vpn_location, run_web_server, utility_thread::run_utility_thread, diff --git a/crates/defguard_core/src/db/models/wireguard.rs b/crates/defguard_core/src/db/models/wireguard.rs index 32c4a4e4f..7ab886b6b 100644 --- a/crates/defguard_core/src/db/models/wireguard.rs +++ b/crates/defguard_core/src/db/models/wireguard.rs @@ -41,7 +41,7 @@ use super::{ }; use crate::{ enterprise::{firewall::FirewallError, is_enterprise_enabled}, - grpc::gateway::{send_multiple_wireguard_events, state::GatewayState}, + grpc::gateway::send_multiple_wireguard_events, wg_config::ImportedDevice, }; @@ -1449,7 +1449,7 @@ pub struct WireguardNetworkInfo { #[serde(flatten)] pub network: WireguardNetwork, pub connected: bool, - pub gateways: Vec, + // pub gateways: Vec, pub allowed_groups: Vec, } diff --git a/crates/defguard_core/src/grpc/gateway/mod.rs b/crates/defguard_core/src/grpc/gateway/mod.rs index b1e48deca..9b9809dc8 100644 --- a/crates/defguard_core/src/grpc/gateway/mod.rs +++ b/crates/defguard_core/src/grpc/gateway/mod.rs @@ -1,23 +1,38 @@ -use std::net::IpAddr; +use std::{ + collections::HashMap, + net::IpAddr, + sync::{Arc, Mutex}, +}; use chrono::{DateTime, Utc}; -use defguard_common::db::{Id, NoId}; +use defguard_common::{ + config::server_config, + db::{ChangeNotification, Id, NoId, TriggerOperation}, +}; +use defguard_mail::Mail; use defguard_proto::{ enterprise::firewall::FirewallConfig, gateway::{Configuration, CoreResponse, Peer, PeerStats, Update, core_response, update}, }; -use tokio::sync::{ - broadcast::{Receiver as BroadcastReceiver, Sender}, - mpsc::UnboundedSender, +use sqlx::{PgPool, postgres::PgListener}; +use tokio::{ + sync::{ + broadcast::{Receiver as BroadcastReceiver, Sender}, + mpsc::UnboundedSender, + }, + task::{AbortHandle, JoinSet}, }; use tonic::{Code, Status}; use crate::{ db::{ GatewayEvent, - models::{wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats}, + models::{ + gateway::Gateway, wireguard::WireguardNetwork, wireguard_peer_stats::WireguardPeerStats, + }, }, - events::GrpcRequestContext, + events::{GrpcEvent, GrpcRequestContext}, + grpc::gateway::{client_state::ClientMap, handler::GatewayHandler}, }; pub mod client_state; @@ -87,6 +102,103 @@ impl WireguardPeerStats { } } +const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; + +/// Bi-directional gRPC stream for comminication with Defguard Gateway. +pub async fn run_grpc_gateway_stream( + pool: PgPool, + client_state: Arc>, + events_tx: Sender, + mail_tx: UnboundedSender, + grpc_event_tx: UnboundedSender, +) -> Result<(), anyhow::Error> { + let config = server_config(); + let tls_config = config.grpc_client_tls_config()?; + + let mut abort_handles = HashMap::new(); + + let mut tasks = JoinSet::new(); + // Helper closure to launch `GatewayHandler`. + let mut launch_gateway_handler = + |gateway: Gateway| -> Result { + let mut gateway_handler = GatewayHandler::new( + gateway, + tls_config.clone(), + pool.clone(), + Arc::clone(&client_state), + events_tx.clone(), + mail_tx.clone(), + grpc_event_tx.clone(), + )?; + let abort_handle = tasks.spawn(async move { + gateway_handler.handle_connection().await; + }); + Ok(abort_handle) + }; + + for gateway in Gateway::all(&pool).await? { + let id = gateway.id; + let abort_handle = launch_gateway_handler(gateway)?; + abort_handles.insert(id, abort_handle); + } + + // Observe gateway URL changes. + let mut listener = PgListener::connect_with(&pool).await?; + listener.listen(GATEWAY_TABLE_TRIGGER).await?; + while let Ok(notification) = listener.recv().await { + let payload = notification.payload(); + match serde_json::from_str::>>(payload) { + Ok(gateway_notification) => match gateway_notification.operation { + TriggerOperation::Insert => { + if let Some(new) = gateway_notification.new { + let id = new.id; + let abort_handle = launch_gateway_handler(new)?; + abort_handles.insert(id, abort_handle); + } + } + TriggerOperation::Update => { + if let (Some(old), Some(new)) = + (gateway_notification.old, gateway_notification.new) + { + if old.url == new.url { + debug!( + "Gateway URL didn't change. Keeping the current gateway handler" + ); + } else if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!("Aborting connection to {old}, it has changed in the database"); + abort_handle.abort(); + let id = new.id; + let abort_handle = launch_gateway_handler(new)?; + abort_handles.insert(id, abort_handle); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + TriggerOperation::Delete => { + if let Some(old) = gateway_notification.old { + if let Some(abort_handle) = abort_handles.remove(&old.id) { + info!( + "Aborting connection to {old}, it has disappeard from the database" + ); + abort_handle.abort(); + } else { + warn!("Cannot find {old} on the list of connected gateways"); + } + } + } + }, + Err(err) => error!("Failed to de-serialize database notification object: {err}"), + } + } + + while let Some(Ok(_result)) = tasks.join_next().await { + debug!("Gateway gRPC task has ended"); + } + + Ok(()) +} + /// Helper struct for handling gateway events. struct GatewayUpdatesHandler { network_id: Id, diff --git a/crates/defguard_core/src/grpc/mod.rs b/crates/defguard_core/src/grpc/mod.rs index 8b9e126ee..fcea09feb 100644 --- a/crates/defguard_core/src/grpc/mod.rs +++ b/crates/defguard_core/src/grpc/mod.rs @@ -10,7 +10,7 @@ use axum::http::Uri; use defguard_common::{ VERSION, auth::claims::ClaimsType, - db::{ChangeNotification, Id, TriggerOperation, models::Settings}, + db::{Id, models::Settings}, }; use defguard_mail::Mail; use defguard_version::{ @@ -20,13 +20,12 @@ use defguard_version::{ use openidconnect::{AuthorizationCode, Nonce, Scope, core::CoreAuthenticationFlow}; use reqwest::Url; use serde::Serialize; -use sqlx::{PgPool, postgres::PgListener}; +use sqlx::PgPool; use tokio::{ sync::{ broadcast::Sender, mpsc::{self, UnboundedSender}, }, - task::{AbortHandle, JoinSet}, time::sleep, }; use tokio_stream::wrappers::UnboundedReceiverStream; @@ -39,17 +38,13 @@ use tonic::{ use self::{ auth::AuthServer, client_mfa::ClientMfaServer, enrollment::EnrollmentServer, - gateway::handler::GatewayHandler, interceptor::JwtInterceptor, - password_reset::PasswordResetServer, worker::WorkerServer, + interceptor::JwtInterceptor, password_reset::PasswordResetServer, worker::WorkerServer, }; use crate::{ auth::failed_login::FailedLoginMap, db::{ AppEvent, GatewayEvent, - models::{ - enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, - gateway::Gateway, - }, + models::enrollment::{ENROLLMENT_TOKEN_TYPE, Token}, }, enterprise::{ db::models::{ @@ -551,103 +546,6 @@ async fn handle_proxy_message_loop( Ok(()) } -const GATEWAY_TABLE_TRIGGER: &str = "gateway_change"; - -/// Bi-directional gRPC stream for comminication with Defguard Gateway. -pub async fn run_grpc_gateway_stream( - pool: PgPool, - client_state: Arc>, - events_tx: Sender, - mail_tx: UnboundedSender, - grpc_event_tx: UnboundedSender, -) -> Result<(), anyhow::Error> { - let config = server_config(); - let tls_config = config.grpc_client_tls_config()?; - - let mut abort_handles = HashMap::new(); - - let mut tasks = JoinSet::new(); - // Helper closure to launch `GatewayHandler`. - let mut launch_gateway_handler = - |gateway: Gateway| -> Result { - let mut gateway_handler = GatewayHandler::new( - gateway, - tls_config.clone(), - pool.clone(), - Arc::clone(&client_state), - events_tx.clone(), - mail_tx.clone(), - grpc_event_tx.clone(), - )?; - let abort_handle = tasks.spawn(async move { - gateway_handler.handle_connection().await; - }); - Ok(abort_handle) - }; - - for gateway in Gateway::all(&pool).await? { - let id = gateway.id; - let abort_handle = launch_gateway_handler(gateway)?; - abort_handles.insert(id, abort_handle); - } - - // Observe gateway URL changes. - let mut listener = PgListener::connect_with(&pool).await?; - listener.listen(GATEWAY_TABLE_TRIGGER).await?; - while let Ok(notification) = listener.recv().await { - let payload = notification.payload(); - match serde_json::from_str::>>(payload) { - Ok(gateway_notification) => match gateway_notification.operation { - TriggerOperation::Insert => { - if let Some(new) = gateway_notification.new { - let id = new.id; - let abort_handle = launch_gateway_handler(new)?; - abort_handles.insert(id, abort_handle); - } - } - TriggerOperation::Update => { - if let (Some(old), Some(new)) = - (gateway_notification.old, gateway_notification.new) - { - if old.url == new.url { - debug!( - "Gateway URL didn't change. Keeping the current gateway handler" - ); - } else if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!("Aborting connection to {old}, it has changed in the database"); - abort_handle.abort(); - let id = new.id; - let abort_handle = launch_gateway_handler(new)?; - abort_handles.insert(id, abort_handle); - } else { - warn!("Cannot find {old} on the list of connected gateways"); - } - } - } - TriggerOperation::Delete => { - if let Some(old) = gateway_notification.old { - if let Some(abort_handle) = abort_handles.remove(&old.id) { - info!( - "Aborting connection to {old}, it has disappeard from the database" - ); - abort_handle.abort(); - } else { - warn!("Cannot find {old} on the list of connected gateways"); - } - } - } - }, - Err(err) => error!("Failed to de-serialize database notification object: {err}"), - } - } - - while let Some(Ok(_result)) = tasks.join_next().await { - debug!("Gateway gRPC task has ended"); - } - - Ok(()) -} - /// Bi-directional gRPC stream for communication with Defguard Proxy. #[instrument(skip_all)] pub async fn run_grpc_bidi_stream( diff --git a/crates/defguard_core/src/handlers/wireguard.rs b/crates/defguard_core/src/handlers/wireguard.rs index 9410134bd..091ebd439 100644 --- a/crates/defguard_core/src/handlers/wireguard.rs +++ b/crates/defguard_core/src/handlers/wireguard.rs @@ -1,12 +1,6 @@ -use std::{ - collections::HashSet, - net::IpAddr, - str::FromStr, - sync::{Arc, Mutex}, -}; +use std::{collections::HashSet, net::IpAddr, str::FromStr}; use axum::{ - Extension, extract::{Json, Path, Query, State}, http::StatusCode, }; @@ -17,7 +11,6 @@ use ipnetwork::IpNetwork; use serde_json::{Value, json}; use sqlx::PgPool; use utoipa::ToSchema; -use uuid::Uuid; use super::{ApiResponse, ApiResult, WebError, device_for_admin_or_self, user_for_admin_or_self}; use crate::{ @@ -44,7 +37,6 @@ use crate::{ limits::update_counts, }, events::{ApiEvent, ApiEventType, ApiRequestContext}, - grpc::gateway::map::GatewayMap, handlers::mail::send_new_device_added_email, server_config, wg_config::{ImportedDevice, parse_wireguard_config}, @@ -445,11 +437,7 @@ pub(crate) async fn delete_network( ("api_token" = []) ) )] -pub(crate) async fn list_networks( - _role: AdminRole, - State(appstate): State, - Extension(gateway_state): Extension>>, -) -> ApiResult { +pub(crate) async fn list_networks(_role: AdminRole, State(appstate): State) -> ApiResult { debug!("Listing WireGuard networks"); let mut network_info = Vec::new(); let networks = WireguardNetwork::all(&appstate.pool).await?; @@ -458,13 +446,10 @@ pub(crate) async fn list_networks( let network_id = network.id; let allowed_groups = network.fetch_allowed_groups(&appstate.pool).await?; { - let gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); network_info.push(WireguardNetworkInfo { network, - connected: gateway_state.connected(network_id), - gateways: gateway_state.get_network_gateway_status(network_id), + connected: false, // FIXME: was: gateway_state.connected(network_id), + // gateways: gateway_state.get_network_gateway_status(network_id), allowed_groups, }); } @@ -504,20 +489,16 @@ pub(crate) async fn network_details( Path(network_id): Path, _role: AdminRole, State(appstate): State, - Extension(gateway_state): Extension>>, ) -> ApiResult { debug!("Displaying network details for network {network_id}"); let network = WireguardNetwork::find_by_id(&appstate.pool, network_id).await?; let response = match network { Some(network) => { let allowed_groups = network.fetch_allowed_groups(&appstate.pool).await?; - let gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); let network_info = WireguardNetworkInfo { network, - connected: gateway_state.connected(network_id), - gateways: gateway_state.get_network_gateway_status(network_id), + connected: false, // FIXME: was: gateway_state.connected(network_id), + // gateways: gateway_state.get_network_gateway_status(network_id), allowed_groups, }; ApiResponse { @@ -539,56 +520,47 @@ pub(crate) async fn network_details( /// /// # Returns /// Returns `Vec` for requested network -pub(crate) async fn gateway_status( - Path(network_id): Path, - _role: AdminRole, - Extension(gateway_state): Extension>>, -) -> ApiResult { +pub(crate) async fn gateway_status(Path(network_id): Path, _role: AdminRole) -> ApiResult { debug!("Displaying gateway status for network {network_id}"); - let gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); + + // TODO: fetch gateways from db + debug!("Displayed gateway status for network {network_id}"); - Ok(ApiResponse { - json: json!(gateway_state.get_network_gateway_status(network_id)), - status: StatusCode::OK, - }) + // Ok(ApiResponse { + // json: json!(gateway_state.get_network_gateway_status(network_id)), + // status: StatusCode::OK, + // }) + Ok(ApiResponse::default()) } /// Returns state of gateways for all networks /// /// Returns current state of gateways as `HashMap>` where key is an id of `WireguardNetwork` -pub(crate) async fn all_gateways_status( - _role: AdminRole, - Extension(gateway_state): Extension>>, -) -> ApiResult { +pub(crate) async fn all_gateways_status(_role: AdminRole) -> ApiResult { debug!("Displaying gateways status for all networks."); - let gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); - let flattened = (*gateway_state).as_flattened(); - Ok(ApiResponse { - json: json!(flattened), - status: StatusCode::OK, - }) + + // let flattened = (*gateway_state).as_flattened(); + // Ok(ApiResponse { + // json: json!(flattened), + // status: StatusCode::OK, + // }) + Ok(ApiResponse::default()) } pub(crate) async fn remove_gateway( Path((network_id, gateway_id)): Path<(i64, String)>, _role: AdminRole, - Extension(gateway_state): Extension>>, ) -> ApiResult { debug!("Removing gateway {gateway_id} in network {network_id}"); - let mut gateway_state = gateway_state - .lock() - .expect("Failed to acquire gateway state lock"); - - gateway_state.remove_gateway( - network_id, - Uuid::from_str(&gateway_id) - .map_err(|_| WebError::Http(StatusCode::INTERNAL_SERVER_ERROR))?, - )?; + + // TODO: fetch gateways from db + + // gateway_state.remove_gateway( + // network_id, + // Uuid::from_str(&gateway_id) + // .map_err(|_| WebError::Http(StatusCode::INTERNAL_SERVER_ERROR))?, + // )?; info!("Removed gateway {gateway_id} in network {network_id}"); diff --git a/crates/defguard_version/src/lib.rs b/crates/defguard_version/src/lib.rs index 05f177b24..8d6881440 100644 --- a/crates/defguard_version/src/lib.rs +++ b/crates/defguard_version/src/lib.rs @@ -62,7 +62,7 @@ use std::{cmp::Ordering, fmt, str::FromStr}; -use ::tracing::{error, warn}; +use ::tracing::warn; pub use semver::{BuildMetadata, Error as SemverError, Prerelease, Version}; use serde::Serialize; use thiserror::Error;