diff --git a/go.mod b/go.mod index 870118c88e7..8f4ec530b74 100644 --- a/go.mod +++ b/go.mod @@ -64,7 +64,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba + github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 diff --git a/go.sum b/go.sum index 96a303f798d..f10e1e6da59 100644 --- a/go.sum +++ b/go.sum @@ -368,8 +368,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba h1:pD6eygRJ5EYAlgzeNskPU3WqszMz6/HhPuc6/Bc/580= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251202114414-534cf891e0ba/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847 h1:V0zsYYMU5d2UN1m9zOLPEZCGWpnhtkYcxQVi9Rrx3bY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251203183432-d5400f030847/go.mod h1:qzLCKeR253jtsWhfZTt4fyegI5zei32jKZykV+oSQOo= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= diff --git a/management/server/account.go b/management/server/account.go index dac040db0f5..b97fc7b0fe3 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -295,10 +295,23 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return err } - if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { + if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil { return err } + if oldSettings.Extra != nil && newSettings.Extra != nil && + oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled { + approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to approve pending peers: %w", err) + } + + if approvedCount > 0 { + log.WithContext(ctx).Debugf("approved %d pending peers in account %s", approvedCount, accountID) + updateAccountPeers = true + } + } + if oldSettings.NetworkRange != newSettings.NetworkRange { if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil { return err @@ -372,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return newSettings, nil } -func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { +func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -386,17 +399,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } - peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") - if err != nil { - return err - } - - peersMap := make(map[string]*nbpeer.Peer, len(peers)) - for _, peer := range peers { - peersMap[peer.ID] = peer - } - - return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID) + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID) } func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { diff --git a/management/server/account_test.go b/management/server/account_test.go index 8569f1b2fb9..7f125e3a0c8 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2058,6 +2058,43 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } +func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) { + manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + accountID := account.Id + userID := account.Users[account.CreatedBy].Id + ctx := context.Background() + + newSettings := account.Settings.Copy() + newSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: true, + } + _, err := manager.UpdateAccountSettings(ctx, accountID, userID, newSettings) + require.NoError(t, err) + + peer1.Status.RequiresApproval = true + peer2.Status.RequiresApproval = true + peer3.Status.RequiresApproval = false + + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer1)) + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer2)) + require.NoError(t, manager.Store.SavePeer(ctx, accountID, peer3)) + + newSettings = account.Settings.Copy() + newSettings.Extra = &types.ExtraSettings{ + PeerApprovalEnabled: false, + } + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, newSettings) + require.NoError(t, err) + + accountPeers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + + for _, peer := range accountPeers { + assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval after disabling peer approval", peer.ID) + } +} + func TestAccount_GetExpiredPeers(t *testing.T) { type test struct { name string diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index e9a1c87018c..69ea668adf2 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -127,7 +127,7 @@ type MockIntegratedValidator struct { ValidatePeerFunc func(_ context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) } -func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error { +func (a MockIntegratedValidator) ValidateExtraSettings(_ context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error { return nil } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index 26c338cb693..326fbfaf01d 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -10,7 +10,7 @@ import ( // IntegratedValidator interface exists to avoid the circle dependencies type IntegratedValidator interface { - ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, peers map[string]*nbpeer.Peer, userID string, accountID string) error + ValidateExtraSettings(ctx context.Context, newExtraSettings *types.ExtraSettings, oldExtraSettings *types.ExtraSettings, userID string, accountID string) error ValidatePeer(ctx context.Context, update *nbpeer.Peer, peer *nbpeer.Peer, userID string, accountID string, dnsDomain string, peersGroup []string, extraSettings *types.ExtraSettings) (*nbpeer.Peer, bool, error) PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 94b7fc1cc1e..cd3b993e182 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -413,6 +413,18 @@ func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerW return nil } +// ApproveAccountPeers marks all peers that currently require approval in the given account as approved. +func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) { + result := s.db.Model(&nbpeer.Peer{}). + Where("account_id = ? AND peer_status_requires_approval = ?", accountID, true). + Update("peer_status_requires_approval", false) + if result.Error != nil { + return 0, status.Errorf(status.Internal, "failed to approve pending account peers: %v", result.Error) + } + + return int(result.RowsAffected), nil +} + // SaveUsers saves the given list of users to the database. func (s *SqlStore) SaveUsers(ctx context.Context, users []*types.User) error { if len(users) == 0 { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index d40c4664c35..2e262391011 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -3717,3 +3717,80 @@ func TestSqlStore_GetPeersByGroupIDs(t *testing.T) { }) } } + +func TestSqlStore_ApproveAccountPeers(t *testing.T) { + runTestForAllEngines(t, "", func(t *testing.T, store Store) { + accountID := "test-account" + ctx := context.Background() + + account := newAccountWithId(ctx, accountID, "testuser", "example.com") + err := store.SaveAccount(ctx, account) + require.NoError(t, err) + + peers := []*nbpeer.Peer{ + { + ID: "peer1", + AccountID: accountID, + DNSLabel: "peer1.netbird.cloud", + Key: "peer1-key", + IP: net.ParseIP("100.64.0.1"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: true, + LastSeen: time.Now().UTC(), + }, + }, + { + ID: "peer2", + AccountID: accountID, + DNSLabel: "peer2.netbird.cloud", + Key: "peer2-key", + IP: net.ParseIP("100.64.0.2"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: true, + LastSeen: time.Now().UTC(), + }, + }, + { + ID: "peer3", + AccountID: accountID, + DNSLabel: "peer3.netbird.cloud", + Key: "peer3-key", + IP: net.ParseIP("100.64.0.3"), + Status: &nbpeer.PeerStatus{ + RequiresApproval: false, + LastSeen: time.Now().UTC(), + }, + }, + } + + for _, peer := range peers { + err = store.AddPeerToAccount(ctx, peer) + require.NoError(t, err) + } + + t.Run("approve all pending peers", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, accountID) + require.NoError(t, err) + assert.Equal(t, 2, count) + + allPeers, err := store.GetAccountPeers(ctx, LockingStrengthNone, accountID, "", "") + require.NoError(t, err) + + for _, peer := range allPeers { + assert.False(t, peer.Status.RequiresApproval, "peer %s should not require approval", peer.ID) + } + }) + + t.Run("no peers to approve", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, accountID) + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + + t.Run("non-existent account", func(t *testing.T) { + count, err := store.ApproveAccountPeers(ctx, "non-existent") + require.NoError(t, err) + assert.Equal(t, 0, count) + }) + }) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 007e2b73944..0ec7949f980 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -143,6 +143,7 @@ type Store interface { SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error + ApproveAccountPeers(ctx context.Context, accountID string) (int, error) DeletePeer(ctx context.Context, accountID string, peerID string) error GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error)