diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 022ea774ce3..f9906acee1e 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" + "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" @@ -175,7 +176,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin dnsCache := &cache.DNSConfigCache{} dnsDomain := c.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) + peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() @@ -196,6 +197,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) + accountZones, err := c.repo.GetAccountZones(ctx, account.Id) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account zones: %v", err) + return fmt.Errorf("failed to get account zones: %v", err) + } + for _, peer := range account.Peers { if !c.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) @@ -222,9 +229,9 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin var remotePeerNetworkMap *types.NetworkMap if c.experimentalNetworkMap(accountID) { - remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) } c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) @@ -317,7 +324,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe dnsCache := &cache.DNSConfigCache{} dnsDomain := c.GetDNSDomain(account.Settings) - customZone := account.GetPeersCustomZone(ctx, dnsDomain) + peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain) resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() @@ -333,12 +340,18 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return err } + accountZones, err := c.repo.GetAccountZones(ctx, account.Id) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account zones: %v", err) + return err + } + var remotePeerNetworkMap *types.NetworkMap if c.experimentalNetworkMap(accountId) { - remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) } else { - remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -432,7 +445,14 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr } log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) - customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + accountZones, err := c.repo.GetAccountZones(ctx, account.Id) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account zones: %v", err) + return nil, nil, nil, 0, err + } + + dnsDomain := c.GetDNSDomain(account.Settings) + peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain) proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peer.ID, account.Peers) if err != nil { @@ -443,9 +463,9 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr var networkMap *types.NetworkMap if c.experimentalNetworkMap(accountID) { - networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, c.accountManagerMetrics) + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics) } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics) + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), c.accountManagerMetrics) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] @@ -468,7 +488,8 @@ func (c *Controller) getPeerNetworkMapExp( accountId string, peerId string, validatedPeers map[string]struct{}, - customZone nbdns.CustomZone, + peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, metrics *telemetry.AccountManagerMetrics, ) *types.NetworkMap { account := c.getAccountFromHolderOrInit(accountId) @@ -478,7 +499,8 @@ func (c *Controller) getPeerNetworkMapExp( Network: &types.Network{}, } } - return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) + + return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics) } func (c *Controller) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { @@ -798,7 +820,15 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N if err != nil { return nil, err } - customZone := account.GetPeersCustomZone(ctx, c.GetDNSDomain(account.Settings)) + + accountZones, err := c.repo.GetAccountZones(ctx, account.Id) + if err != nil { + log.WithContext(ctx).Errorf("failed to get account zones: %v", err) + return nil, err + } + + dnsDomain := c.GetDNSDomain(account.Settings) + peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain) proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMaps(ctx, account.Id, peerID, account.Peers) if err != nil { @@ -809,9 +839,9 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N var networkMap *types.NetworkMap if c.experimentalNetworkMap(peer.AccountID) { - networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil) } else { - networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] diff --git a/management/internals/controllers/network_map/controller/repository.go b/management/internals/controllers/network_map/controller/repository.go index 3ed51a5c3bd..caef362cbdd 100644 --- a/management/internals/controllers/network_map/controller/repository.go +++ b/management/internals/controllers/network_map/controller/repository.go @@ -3,6 +3,7 @@ package controller import ( "context" + "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" @@ -14,6 +15,7 @@ type Repository interface { GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) + GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) } type repository struct { @@ -47,3 +49,7 @@ func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerID func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) { return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) } + +func (r *repository) GetAccountZones(ctx context.Context, accountID string) ([]*zones.Zone, error) { + return r.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID) +} diff --git a/management/internals/modules/zones/interface.go b/management/internals/modules/zones/interface.go new file mode 100644 index 00000000000..8e23062308b --- /dev/null +++ b/management/internals/modules/zones/interface.go @@ -0,0 +1,13 @@ +package zones + +import ( + "context" +) + +type Manager interface { + GetAllZones(ctx context.Context, accountID, userID string) ([]*Zone, error) + GetZone(ctx context.Context, accountID, userID, zone string) (*Zone, error) + CreateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error) + UpdateZone(ctx context.Context, accountID, userID string, zone *Zone) (*Zone, error) + DeleteZone(ctx context.Context, accountID, userID, zoneID string) error +} diff --git a/management/internals/modules/zones/manager/api.go b/management/internals/modules/zones/manager/api.go new file mode 100644 index 00000000000..919d77d61d4 --- /dev/null +++ b/management/internals/modules/zones/manager/api.go @@ -0,0 +1,161 @@ +package manager + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/zones" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +type handler struct { + manager zones.Manager +} + +func RegisterEndpoints(router *mux.Router, manager zones.Manager) { + h := &handler{ + manager: manager, + } + + router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS") +} + +func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + apiZones := make([]*api.Zone, 0, len(allZones)) + for _, zone := range allZones { + apiZones = append(apiZones, zone.ToAPIResponse()) + } + + util.WriteJSONObject(r.Context(), w, apiZones) +} + +func (h *handler) createZone(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + var req api.PostApiDnsZonesJSONRequestBody + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + zone := new(zones.Zone) + zone.FromAPIRequest(&req) + + if err = zone.Validate(); err != nil { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) + return + } + + createdZone, err := h.manager.CreateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse()) +} + +func (h *handler) getZone(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + zoneID := mux.Vars(r)["zoneId"] + if zoneID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) + return + } + + zone, err := h.manager.GetZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse()) +} + +func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + zoneID := mux.Vars(r)["zoneId"] + if zoneID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) + return + } + + var req api.PutApiDnsZonesZoneIdJSONRequestBody + if err = json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + zone := new(zones.Zone) + zone.FromAPIRequest(&req) + zone.ID = zoneID + + if err = zone.Validate(); err != nil { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) + return + } + + updatedZone, err := h.manager.UpdateZone(r.Context(), userAuth.AccountId, userAuth.UserId, zone) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse()) +} + +func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + zoneID := mux.Vars(r)["zoneId"] + if zoneID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) + return + } + + if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/internals/modules/zones/manager/manager.go b/management/internals/modules/zones/manager/manager.go new file mode 100644 index 00000000000..8548dd48cee --- /dev/null +++ b/management/internals/modules/zones/manager/manager.go @@ -0,0 +1,229 @@ +package manager + +import ( + "context" + "fmt" + + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +type managerImpl struct { + store store.Store + accountManager account.Manager + permissionsManager permissions.Manager + dnsDomain string +} + +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager, dnsDomain string) zones.Manager { + return &managerImpl{ + store: store, + accountManager: accountManager, + permissionsManager: permissionsManager, + dnsDomain: dnsDomain, + } +} + +func (m *managerImpl) GetAllZones(ctx context.Context, accountID, userID string) ([]*zones.Zone, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetAccountZones(ctx, store.LockingStrengthNone, accountID) +} + +func (m *managerImpl) GetZone(ctx context.Context, accountID, userID, zoneID string) (*zones.Zone, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetZoneByID(ctx, store.LockingStrengthNone, accountID, zoneID) +} + +func (m *managerImpl) CreateZone(ctx context.Context, accountID, userID string, zone *zones.Zone) (*zones.Zone, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + if err = m.validateZoneDomainConflict(ctx, accountID, zone.Domain); err != nil { + return nil, err + } + + zone = zones.NewZone(accountID, zone.Name, zone.Domain, zone.Enabled, zone.EnableSearchDomain, zone.DistributionGroups) + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + existingZone, err := transaction.GetZoneByDomain(ctx, accountID, zone.Domain) + if err != nil { + if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { + return fmt.Errorf("failed to check existing zone: %w", err) + } + } + if existingZone != nil { + return status.Errorf(status.AlreadyExists, "zone with domain %s already exists", zone.Domain) + } + + for _, groupID := range zone.DistributionGroups { + _, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) + if err != nil { + return status.Errorf(status.InvalidArgument, "%s", err.Error()) + } + } + + if err = transaction.CreateZone(ctx, zone); err != nil { + return fmt.Errorf("failed to create zone: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneCreated, zone.EventMeta()) + + return zone, nil +} + +func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string, updatedZone *zones.Zone) (*zones.Zone, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, updatedZone.ID) + if err != nil { + return nil, fmt.Errorf("failed to get zone: %w", err) + } + + if zone.Domain != updatedZone.Domain { + return nil, status.Errorf(status.InvalidArgument, "zone domain cannot be updated") + } + + zone.Name = updatedZone.Name + zone.Enabled = updatedZone.Enabled + zone.EnableSearchDomain = updatedZone.EnableSearchDomain + zone.DistributionGroups = updatedZone.DistributionGroups + + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + for _, groupID := range zone.DistributionGroups { + _, err = transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID) + if err != nil { + return status.Errorf(status.InvalidArgument, "%s", err.Error()) + } + } + + if err = transaction.UpdateZone(ctx, zone); err != nil { + return fmt.Errorf("failed to update zone: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta()) + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return zone, nil +} + +func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + zone, err := m.store.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) + if err != nil { + return fmt.Errorf("failed to get zone: %w", err) + } + + var eventsToStore []func() + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + records, err := transaction.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID) + if err != nil { + return fmt.Errorf("failed to get records: %w", err) + } + + err = transaction.DeleteZoneDNSRecords(ctx, accountID, zoneID) + if err != nil { + return fmt.Errorf("failed to delete zone dns records: %w", err) + } + + err = transaction.DeleteZone(ctx, accountID, zoneID) + if err != nil { + return fmt.Errorf("failed to delete zone: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + for _, record := range records { + eventsToStore = append(eventsToStore, func() { + meta := record.EventMeta(zone.ID, zone.Name) + m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordDeleted, meta) + }) + } + + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, zoneID, accountID, activity.DNSZoneDeleted, zone.EventMeta()) + }) + + return nil + }) + if err != nil { + return err + } + + for _, event := range eventsToStore { + event() + } + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +func (m *managerImpl) validateZoneDomainConflict(ctx context.Context, accountID, domain string) error { + if m.dnsDomain != "" && m.dnsDomain == domain { + return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain) + } + + settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return err + } + + if settings.DNSDomain != "" && settings.DNSDomain == domain { + return status.Errorf(status.InvalidArgument, "zone domain %s conflicts with peer DNS domain", domain) + } + + return nil +} diff --git a/management/internals/modules/zones/manager/manager_test.go b/management/internals/modules/zones/manager/manager_test.go new file mode 100644 index 00000000000..b45ec787417 --- /dev/null +++ b/management/internals/modules/zones/manager/manager_test.go @@ -0,0 +1,553 @@ +package manager + +import ( + "context" + "fmt" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + testAccountID = "test-account-id" + testUserID = "test-user-id" + testZoneID = "test-zone-id" + testGroupID = "test-group-id" + testDNSDomain = "netbird.selfhosted" +) + +func setupTest(t *testing.T) (*managerImpl, store.Store, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) { + t.Helper() + + ctx := context.Background() + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + + err = testStore.SaveAccount(ctx, &types.Account{ + Id: testAccountID, + Groups: map[string]*types.Group{ + testGroupID: { + ID: testGroupID, + Name: "Test Group", + }, + }, + }) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + mockAccountManager := &mock_server.MockAccountManager{} + mockPermissionsManager := permissions.NewMockManager(ctrl) + + manager := &managerImpl{ + store: testStore, + accountManager: mockAccountManager, + permissionsManager: mockPermissionsManager, + dnsDomain: testDNSDomain, + } + + return manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup +} + +func TestManagerImpl_GetAllZones(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + zone1 := zones.NewZone(testAccountID, "Zone 1", "zone1.example.com", true, true, []string{testGroupID}) + err := testStore.CreateZone(ctx, zone1) + require.NoError(t, err) + + zone2 := zones.NewZone(testAccountID, "Zone 2", "zone2.example.com", false, false, []string{testGroupID}) + err = testStore.CreateZone(ctx, zone2) + require.NoError(t, err) + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(true, nil) + + result, err := manager.GetAllZones(ctx, testAccountID, testUserID) + require.NoError(t, err) + assert.Len(t, result, 2) + assert.Equal(t, zone1.ID, result[0].ID) + assert.Equal(t, zone2.ID, result[1].ID) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(false, nil) + + result, err := manager.GetAllZones(ctx, testAccountID, testUserID) + require.Error(t, err) + assert.Nil(t, result) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) + + t.Run("permission validation error", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(false, status.Errorf(status.Internal, "permission check failed")) + + result, err := manager.GetAllZones(ctx, testAccountID, testUserID) + require.Error(t, err) + assert.Nil(t, result) + }) +} + +func TestManagerImpl_GetZone(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + zone := zones.NewZone(testAccountID, "Test Zone", "test.example.com", true, true, []string{testGroupID}) + err := testStore.CreateZone(ctx, zone) + require.NoError(t, err) + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(true, nil) + + result, err := manager.GetZone(ctx, testAccountID, testUserID, zone.ID) + require.NoError(t, err) + assert.Equal(t, zone.ID, result.ID) + assert.Equal(t, zone.Name, result.Name) + assert.Equal(t, zone.Domain, result.Domain) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(false, nil) + + result, err := manager.GetZone(ctx, testAccountID, testUserID, testZoneID) + require.Error(t, err) + assert.Nil(t, result) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) +} + +func TestManagerImpl_CreateZone(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + manager, _, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputZone := &zones.Zone{ + Name: "New Zone", + Domain: "new.example.com", + Enabled: true, + EnableSearchDomain: true, + DistributionGroups: []string{testGroupID}, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, testAccountID, accountID) + assert.Equal(t, activity.DNSZoneCreated, activityID) + } + + result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) + require.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result.ID) + assert.Equal(t, testAccountID, result.AccountID) + assert.Equal(t, inputZone.Name, result.Name) + assert.Equal(t, inputZone.Domain, result.Domain) + assert.Equal(t, inputZone.Enabled, result.Enabled) + assert.Equal(t, inputZone.EnableSearchDomain, result.EnableSearchDomain) + assert.Equal(t, inputZone.DistributionGroups, result.DistributionGroups) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputZone := &zones.Zone{ + Name: "New Zone", + Domain: "new.example.com", + DistributionGroups: []string{testGroupID}, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(false, nil) + + result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) + require.Error(t, err) + assert.Nil(t, result) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) + + t.Run("invalid group", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputZone := &zones.Zone{ + Name: "New Zone", + Domain: "new.example.com", + DistributionGroups: []string{"invalid-group"}, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("duplicate domain", func(t *testing.T) { + manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + existingZone := zones.NewZone(testAccountID, "Existing Zone", "duplicate.example.com", true, false, []string{testGroupID}) + err := testStore.CreateZone(ctx, existingZone) + require.NoError(t, err) + + inputZone := &zones.Zone{ + Name: "New Zone", + Domain: "duplicate.example.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{testGroupID}, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "zone with domain duplicate.example.com already exists") + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.AlreadyExists, s.Type()) + }) + + t.Run("peer DNS domain conflict", func(t *testing.T) { + manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + account, err := testStore.GetAccount(ctx, testAccountID) + require.NoError(t, err) + account.Settings.DNSDomain = "peers.example.com" + err = testStore.SaveAccount(ctx, account) + require.NoError(t, err) + + inputZone := &zones.Zone{ + Name: "Test Zone", + Domain: "peers.example.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{testGroupID}, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "zone domain peers.example.com conflicts with peer DNS domain") + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.InvalidArgument, s.Type()) + }) + + t.Run("default DNS domain conflict", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputZone := &zones.Zone{ + Name: "Test Zone", + Domain: testDNSDomain, + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{testGroupID}, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + result, err := manager.CreateZone(ctx, testAccountID, testUserID, inputZone) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), fmt.Sprintf("zone domain %s conflicts with peer DNS domain", testDNSDomain)) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.InvalidArgument, s.Type()) + }) +} + +func TestManagerImpl_UpdateZone(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + existingZone := zones.NewZone(testAccountID, "Old Name", "example.com", false, false, []string{testGroupID}) + err := testStore.CreateZone(ctx, existingZone) + require.NoError(t, err) + + updatedZone := &zones.Zone{ + ID: existingZone.ID, + Name: "Updated Name", + Domain: "example.com", + Enabled: true, + EnableSearchDomain: true, + DistributionGroups: []string{testGroupID}, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(true, nil) + + storeEventCalled := false + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + storeEventCalled = true + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, existingZone.ID, targetID) + assert.Equal(t, testAccountID, accountID) + assert.Equal(t, activity.DNSZoneUpdated, activityID) + } + + result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, updatedZone.Name, result.Name) + assert.Equal(t, updatedZone.Enabled, result.Enabled) + assert.Equal(t, updatedZone.EnableSearchDomain, result.EnableSearchDomain) + assert.True(t, storeEventCalled, "StoreEvent should have been called") + }) + + t.Run("domain change not allowed", func(t *testing.T) { + manager, testStore, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + existingZone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID}) + err := testStore.CreateZone(ctx, existingZone) + require.NoError(t, err) + + updatedZone := &zones.Zone{ + ID: existingZone.ID, + Name: "Test Zone", + Domain: "different.com", + Enabled: true, + EnableSearchDomain: true, + DistributionGroups: []string{testGroupID}, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(true, nil) + + result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "zone domain cannot be updated") + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.InvalidArgument, s.Type()) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + updatedZone := &zones.Zone{ + ID: testZoneID, + Name: "Updated Name", + Domain: "example.com", + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(false, nil) + + result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) + require.Error(t, err) + assert.Nil(t, result) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) + + t.Run("zone not found", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + updatedZone := &zones.Zone{ + ID: "non-existent-zone", + Name: "Updated Name", + Domain: "example.com", + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(true, nil) + + result, err := manager.UpdateZone(ctx, testAccountID, testUserID, updatedZone) + require.Error(t, err) + assert.Nil(t, result) + }) +} + +func TestManagerImpl_DeleteZone(t *testing.T) { + ctx := context.Background() + + t.Run("success with records", func(t *testing.T) { + manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID}) + err := testStore.CreateZone(ctx, zone) + require.NoError(t, err) + + record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err = testStore.CreateDNSRecord(ctx, record1) + require.NoError(t, err) + + record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300) + err = testStore.CreateDNSRecord(ctx, record2) + require.NoError(t, err) + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). + Return(true, nil) + + storeEventCallCount := 0 + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + storeEventCallCount++ + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, testAccountID, accountID) + } + + err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID) + require.NoError(t, err) + assert.Equal(t, 3, storeEventCallCount) + + _, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID) + require.Error(t, err) + + zoneRecords, err := testStore.GetZoneDNSRecords(ctx, store.LockingStrengthNone, testAccountID, zone.ID) + require.NoError(t, err) + assert.Empty(t, zoneRecords) + }) + + t.Run("success without records", func(t *testing.T) { + manager, testStore, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID}) + err := testStore.CreateZone(ctx, zone) + require.NoError(t, err) + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). + Return(true, nil) + + storeEventCalled := false + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + storeEventCalled = true + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, zone.ID, targetID) + assert.Equal(t, testAccountID, accountID) + assert.Equal(t, activity.DNSZoneDeleted, activityID) + } + + err = manager.DeleteZone(ctx, testAccountID, testUserID, zone.ID) + require.NoError(t, err) + assert.True(t, storeEventCalled, "StoreEvent should have been called") + + _, err = testStore.GetZoneByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID) + require.Error(t, err) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). + Return(false, nil) + + err := manager.DeleteZone(ctx, testAccountID, testUserID, testZoneID) + require.Error(t, err) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) + + t.Run("zone not found", func(t *testing.T) { + manager, _, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). + Return(true, nil) + + err := manager.DeleteZone(ctx, testAccountID, testUserID, "non-existent-zone") + require.Error(t, err) + }) +} diff --git a/management/internals/modules/zones/records/interface.go b/management/internals/modules/zones/records/interface.go new file mode 100644 index 00000000000..ceb8c531864 --- /dev/null +++ b/management/internals/modules/zones/records/interface.go @@ -0,0 +1,13 @@ +package records + +import ( + "context" +) + +type Manager interface { + GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*Record, error) + GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*Record, error) + CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error) + UpdateRecord(ctx context.Context, accountID, userID, zoneID string, record *Record) (*Record, error) + DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error +} diff --git a/management/internals/modules/zones/records/manager/api.go b/management/internals/modules/zones/records/manager/api.go new file mode 100644 index 00000000000..f8ecfef7d55 --- /dev/null +++ b/management/internals/modules/zones/records/manager/api.go @@ -0,0 +1,191 @@ +package manager + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + + "github.com/netbirdio/netbird/management/internals/modules/zones/records" + nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" + "github.com/netbirdio/netbird/shared/management/status" +) + +type handler struct { + manager records.Manager +} + +func RegisterEndpoints(router *mux.Router, manager records.Manager) { + h := &handler{ + manager: manager, + } + + router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS") +} + +func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + zoneID := mux.Vars(r)["zoneId"] + if zoneID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) + return + } + + allRecords, err := h.manager.GetAllRecords(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + apiRecords := make([]*api.DNSRecord, 0, len(allRecords)) + for _, record := range allRecords { + apiRecords = append(apiRecords, record.ToAPIResponse()) + } + + util.WriteJSONObject(r.Context(), w, apiRecords) +} + +func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + zoneID := mux.Vars(r)["zoneId"] + if zoneID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) + return + } + + var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + record := new(records.Record) + record.FromAPIRequest(&req) + + if err = record.Validate(); err != nil { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) + return + } + + createdRecord, err := h.manager.CreateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse()) +} + +func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + zoneID := mux.Vars(r)["zoneId"] + if zoneID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) + return + } + + recordID := mux.Vars(r)["recordId"] + if recordID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w) + return + } + + record, err := h.manager.GetRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, record.ToAPIResponse()) +} + +func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + zoneID := mux.Vars(r)["zoneId"] + if zoneID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) + return + } + + recordID := mux.Vars(r)["recordId"] + if recordID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w) + return + } + + var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody + if err = json.NewDecoder(r.Body).Decode(&req); err != nil { + util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) + return + } + + record := new(records.Record) + record.FromAPIRequest(&req) + record.ID = recordID + + if err = record.Validate(); err != nil { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) + return + } + + updatedRecord, err := h.manager.UpdateRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, record) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse()) +} + +func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + zoneID := mux.Vars(r)["zoneId"] + if zoneID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) + return + } + + recordID := mux.Vars(r)["recordId"] + if recordID == "" { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "record ID is required"), w) + return + } + + if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil { + util.WriteError(r.Context(), err, w) + return + } + + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) +} diff --git a/management/internals/modules/zones/records/manager/manager.go b/management/internals/modules/zones/records/manager/manager.go new file mode 100644 index 00000000000..5374a2ef2a3 --- /dev/null +++ b/management/internals/modules/zones/records/manager/manager.go @@ -0,0 +1,236 @@ +package manager + +import ( + "context" + "fmt" + "strings" + + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" + "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/shared/management/status" +) + +type managerImpl struct { + store store.Store + accountManager account.Manager + permissionsManager permissions.Manager +} + +func NewManager(store store.Store, accountManager account.Manager, permissionsManager permissions.Manager) records.Manager { + return &managerImpl{ + store: store, + accountManager: accountManager, + permissionsManager: permissionsManager, + } +} + +func (m *managerImpl) GetAllRecords(ctx context.Context, accountID, userID, zoneID string) ([]*records.Record, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetZoneDNSRecords(ctx, store.LockingStrengthNone, accountID, zoneID) +} + +func (m *managerImpl) GetRecord(ctx context.Context, accountID, userID, zoneID, recordID string) (*records.Record, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + return m.store.GetDNSRecordByID(ctx, store.LockingStrengthNone, accountID, zoneID, recordID) +} + +func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneID string, record *records.Record) (*records.Record, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Create) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + var zone *zones.Zone + + record = records.NewRecord(accountID, zoneID, record.Name, record.Type, record.Content, record.TTL) + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) + if err != nil { + return fmt.Errorf("failed to get zone: %w", err) + } + + err = validateRecordConflicts(ctx, transaction, zone, record) + if err != nil { + return err + } + + if err = transaction.CreateDNSRecord(ctx, record); err != nil { + return fmt.Errorf("failed to create dns record: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + meta := record.EventMeta(zone.ID, zone.Name) + m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta) + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return record, nil +} + +func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneID string, updatedRecord *records.Record) (*records.Record, error) { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Update) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !ok { + return nil, status.NewPermissionDeniedError() + } + + var zone *zones.Zone + var record *records.Record + + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) + if err != nil { + return fmt.Errorf("failed to get zone: %w", err) + } + + record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, updatedRecord.ID) + if err != nil { + return fmt.Errorf("failed to get record: %w", err) + } + + hasChanges := record.Name != updatedRecord.Name || record.Type != updatedRecord.Type || record.Content != updatedRecord.Content + + record.Name = updatedRecord.Name + record.Type = updatedRecord.Type + record.Content = updatedRecord.Content + record.TTL = updatedRecord.TTL + + if hasChanges { + if err = validateRecordConflicts(ctx, transaction, zone, record); err != nil { + return err + } + } + + if err = transaction.UpdateDNSRecord(ctx, record); err != nil { + return fmt.Errorf("failed to update dns record: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return nil, err + } + + meta := record.EventMeta(zone.ID, zone.Name) + m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta) + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return record, nil +} + +func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneID, recordID string) error { + ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Dns, operations.Delete) + if err != nil { + return status.NewPermissionValidationError(err) + } + if !ok { + return status.NewPermissionDeniedError() + } + + var record *records.Record + var zone *zones.Zone + + err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + zone, err = transaction.GetZoneByID(ctx, store.LockingStrengthUpdate, accountID, zoneID) + if err != nil { + return fmt.Errorf("failed to get zone: %w", err) + } + + record, err = transaction.GetDNSRecordByID(ctx, store.LockingStrengthUpdate, accountID, zoneID, recordID) + if err != nil { + return fmt.Errorf("failed to get record: %w", err) + } + + err = transaction.DeleteDNSRecord(ctx, accountID, zoneID, recordID) + if err != nil { + return fmt.Errorf("failed to delete dns record: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return nil + }) + if err != nil { + return err + } + + meta := record.EventMeta(zone.ID, zone.Name) + m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta) + + go m.accountManager.UpdateAccountPeers(ctx, accountID) + + return nil +} + +// validateRecordConflicts checks for duplicate records and CNAME conflicts +func validateRecordConflicts(ctx context.Context, transaction store.Store, zone *zones.Zone, record *records.Record) error { + if record.Name != zone.Domain && !strings.HasSuffix(record.Name, "."+zone.Domain) { + return status.Errorf(status.InvalidArgument, "record name does not belong to zone") + } + + existingRecords, err := transaction.GetZoneDNSRecordsByName(ctx, store.LockingStrengthNone, zone.AccountID, zone.ID, record.Name) + if err != nil { + return fmt.Errorf("failed to check existing records: %w", err) + } + + for _, existing := range existingRecords { + if existing.ID == record.ID { + continue + } + + if existing.Type == record.Type && existing.Content == record.Content { + return status.Errorf(status.AlreadyExists, "identical record already exists") + } + + if record.Type == records.RecordTypeCNAME || existing.Type == records.RecordTypeCNAME { + return status.Errorf(status.InvalidArgument, + "An A, AAAA, or CNAME record with name %s already exists", record.Name) + } + } + + return nil +} diff --git a/management/internals/modules/zones/records/manager/manager_test.go b/management/internals/modules/zones/records/manager/manager_test.go new file mode 100644 index 00000000000..0a962e0f4ac --- /dev/null +++ b/management/internals/modules/zones/records/manager/manager_test.go @@ -0,0 +1,573 @@ +package manager + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" + "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/mock_server" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/management/server/store" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/status" +) + +const ( + testAccountID = "test-account-id" + testUserID = "test-user-id" + testRecordID = "test-record-id" + testGroupID = "test-group-id" +) + +func setupTest(t *testing.T) (*managerImpl, store.Store, *zones.Zone, *mock_server.MockAccountManager, *permissions.MockManager, *gomock.Controller, func()) { + t.Helper() + + ctx := context.Background() + testStore, cleanup, err := store.NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + + err = testStore.SaveAccount(ctx, &types.Account{ + Id: testAccountID, + Groups: map[string]*types.Group{ + testGroupID: { + ID: testGroupID, + Name: "Test Group", + }, + }, + }) + require.NoError(t, err) + + zone := zones.NewZone(testAccountID, "Test Zone", "example.com", true, true, []string{testGroupID}) + err = testStore.CreateZone(ctx, zone) + require.NoError(t, err) + + ctrl := gomock.NewController(t) + mockAccountManager := &mock_server.MockAccountManager{} + mockPermissionsManager := permissions.NewMockManager(ctrl) + + manager := &managerImpl{ + store: testStore, + accountManager: mockAccountManager, + permissionsManager: mockPermissionsManager, + } + + return manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup +} + +func TestManagerImpl_GetAllRecords(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err := testStore.CreateDNSRecord(ctx, record1) + require.NoError(t, err) + + record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300) + err = testStore.CreateDNSRecord(ctx, record2) + require.NoError(t, err) + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(true, nil) + + result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID) + require.NoError(t, err) + assert.Len(t, result, 2) + assert.Equal(t, record1.ID, result[0].ID) + assert.Equal(t, record2.ID, result[1].ID) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(false, nil) + + result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID) + require.Error(t, err) + assert.Nil(t, result) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) + + t.Run("permission validation error", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(false, status.Errorf(status.Internal, "permission check failed")) + + result, err := manager.GetAllRecords(ctx, testAccountID, testUserID, zone.ID) + require.Error(t, err) + assert.Nil(t, result) + }) +} + +func TestManagerImpl_GetRecord(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err := testStore.CreateDNSRecord(ctx, record) + require.NoError(t, err) + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(true, nil) + + result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, record.ID) + require.NoError(t, err) + assert.Equal(t, record.ID, result.ID) + assert.Equal(t, record.Name, result.Name) + assert.Equal(t, record.Type, result.Type) + assert.Equal(t, record.Content, result.Content) + assert.Equal(t, record.TTL, result.TTL) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Read). + Return(false, nil) + + result, err := manager.GetRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID) + require.Error(t, err) + assert.Nil(t, result) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) +} + +func TestManagerImpl_CreateRecord(t *testing.T) { + ctx := context.Background() + + t.Run("success - A record", func(t *testing.T) { + manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputRecord := &records.Record{ + Name: "api.example.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, testAccountID, accountID) + assert.Equal(t, activity.DNSRecordCreated, activityID) + } + + result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) + require.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result.ID) + assert.Equal(t, testAccountID, result.AccountID) + assert.Equal(t, zone.ID, result.ZoneID) + assert.Equal(t, inputRecord.Name, result.Name) + assert.Equal(t, inputRecord.Type, result.Type) + assert.Equal(t, inputRecord.Content, result.Content) + assert.Equal(t, inputRecord.TTL, result.TTL) + }) + + t.Run("success - AAAA record", func(t *testing.T) { + manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputRecord := &records.Record{ + Name: "ipv6.example.com", + Type: records.RecordTypeAAAA, + Content: "2001:db8::1", + TTL: 600, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, testAccountID, accountID) + assert.Equal(t, activity.DNSRecordCreated, activityID) + } + + result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, inputRecord.Type, result.Type) + assert.Equal(t, inputRecord.Content, result.Content) + }) + + t.Run("success - CNAME record", func(t *testing.T) { + manager, _, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputRecord := &records.Record{ + Name: "www.example.com", + Type: records.RecordTypeCNAME, + Content: "example.com", + TTL: 300, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, testAccountID, accountID) + assert.Equal(t, activity.DNSRecordCreated, activityID) + } + + result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, inputRecord.Type, result.Type) + assert.Equal(t, inputRecord.Content, result.Content) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputRecord := &records.Record{ + Name: "api.example.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(false, nil) + + result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) + require.Error(t, err) + assert.Nil(t, result) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) + + t.Run("record name not in zone", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + inputRecord := &records.Record{ + Name: "api.different.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "does not belong to zone") + }) + + t.Run("duplicate record", func(t *testing.T) { + manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err := testStore.CreateDNSRecord(ctx, existingRecord) + require.NoError(t, err) + + inputRecord := &records.Record{ + Name: "api.example.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "identical record already exists") + }) + + t.Run("CNAME conflict with existing A record", func(t *testing.T) { + manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err := testStore.CreateDNSRecord(ctx, existingRecord) + require.NoError(t, err) + + inputRecord := &records.Record{ + Name: "api.example.com", + Type: records.RecordTypeCNAME, + Content: "example.com", + TTL: 300, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Create). + Return(true, nil) + + result, err := manager.CreateRecord(ctx, testAccountID, testUserID, zone.ID, inputRecord) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "already exists") + }) +} + +func TestManagerImpl_UpdateRecord(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err := testStore.CreateDNSRecord(ctx, existingRecord) + require.NoError(t, err) + + updatedRecord := &records.Record{ + ID: existingRecord.ID, + Name: "api.example.com", + Type: records.RecordTypeA, + Content: "192.168.1.100", // Changed IP + TTL: 600, // Changed TTL + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(true, nil) + + storeEventCalled := false + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + storeEventCalled = true + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, existingRecord.ID, targetID) + assert.Equal(t, testAccountID, accountID) + assert.Equal(t, activity.DNSRecordUpdated, activityID) + } + + result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, updatedRecord.Content, result.Content) + assert.Equal(t, updatedRecord.TTL, result.TTL) + assert.True(t, storeEventCalled, "StoreEvent should have been called") + }) + + t.Run("update only TTL - no validation", func(t *testing.T) { + manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + existingRecord := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err := testStore.CreateDNSRecord(ctx, existingRecord) + require.NoError(t, err) + + updatedRecord := &records.Record{ + ID: existingRecord.ID, + Name: existingRecord.Name, + Type: existingRecord.Type, + Content: existingRecord.Content, + TTL: 600, // Only TTL changed + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(true, nil) + + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + // Event should be stored + } + + result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) + require.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, 600, result.TTL) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + updatedRecord := &records.Record{ + ID: testRecordID, + Name: "api.example.com", + Type: records.RecordTypeA, + Content: "192.168.1.100", + TTL: 600, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(false, nil) + + result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) + require.Error(t, err) + assert.Nil(t, result) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) + + t.Run("record not found", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + updatedRecord := &records.Record{ + ID: "non-existent-record", + Name: "api.example.com", + Type: records.RecordTypeA, + Content: "192.168.1.100", + TTL: 600, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(true, nil) + + result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) + require.Error(t, err) + assert.Nil(t, result) + }) + + t.Run("update creates duplicate", func(t *testing.T) { + manager, testStore, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + record1 := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err := testStore.CreateDNSRecord(ctx, record1) + require.NoError(t, err) + + record2 := records.NewRecord(testAccountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.2", 300) + err = testStore.CreateDNSRecord(ctx, record2) + require.NoError(t, err) + + updatedRecord := &records.Record{ + ID: record2.ID, + Name: "api.example.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + } + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Update). + Return(true, nil) + + result, err := manager.UpdateRecord(ctx, testAccountID, testUserID, zone.ID, updatedRecord) + require.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "identical record already exists") + }) +} + +func TestManagerImpl_DeleteRecord(t *testing.T) { + ctx := context.Background() + + t.Run("success", func(t *testing.T) { + manager, testStore, zone, mockAccountManager, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + record := records.NewRecord(testAccountID, zone.ID, "api.example.com", records.RecordTypeA, "192.168.1.1", 300) + err := testStore.CreateDNSRecord(ctx, record) + require.NoError(t, err) + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). + Return(true, nil) + + storeEventCalled := false + mockAccountManager.StoreEventFunc = func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) { + storeEventCalled = true + assert.Equal(t, testUserID, initiatorID) + assert.Equal(t, record.ID, targetID) + assert.Equal(t, testAccountID, accountID) + assert.Equal(t, activity.DNSRecordDeleted, activityID) + } + + err = manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, record.ID) + require.NoError(t, err) + assert.True(t, storeEventCalled, "StoreEvent should have been called") + + _, err = testStore.GetDNSRecordByID(ctx, store.LockingStrengthNone, testAccountID, zone.ID, record.ID) + require.Error(t, err) + }) + + t.Run("permission denied", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). + Return(false, nil) + + err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, testRecordID) + require.Error(t, err) + s, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, status.PermissionDenied, s.Type()) + }) + + t.Run("record not found", func(t *testing.T) { + manager, _, zone, _, mockPermissionsManager, ctrl, cleanup := setupTest(t) + defer cleanup() + defer ctrl.Finish() + + mockPermissionsManager.EXPECT(). + ValidateUserPermissions(ctx, testAccountID, testUserID, modules.Dns, operations.Delete). + Return(true, nil) + + err := manager.DeleteRecord(ctx, testAccountID, testUserID, zone.ID, "non-existent-record") + require.Error(t, err) + }) +} diff --git a/management/internals/modules/zones/records/record.go b/management/internals/modules/zones/records/record.go new file mode 100644 index 00000000000..97e48d688ea --- /dev/null +++ b/management/internals/modules/zones/records/record.go @@ -0,0 +1,129 @@ +package records + +import ( + "errors" + "net" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +type RecordType string + +const ( + RecordTypeA RecordType = "A" + RecordTypeAAAA RecordType = "AAAA" + RecordTypeCNAME RecordType = "CNAME" +) + +type Record struct { + AccountID string `gorm:"index"` + ZoneID string `gorm:"index"` + ID string `gorm:"primaryKey"` + Name string + Type RecordType + Content string + TTL int +} + +func NewRecord(accountID, zoneID, name string, recordType RecordType, content string, ttl int) *Record { + return &Record{ + ID: xid.New().String(), + AccountID: accountID, + ZoneID: zoneID, + Name: name, + Type: recordType, + Content: content, + TTL: ttl, + } +} + +func (r *Record) ToAPIResponse() *api.DNSRecord { + recordType := api.DNSRecordType(r.Type) + return &api.DNSRecord{ + Id: r.ID, + Name: r.Name, + Type: recordType, + Content: r.Content, + Ttl: r.TTL, + } +} + +func (r *Record) FromAPIRequest(req *api.DNSRecordRequest) { + r.Name = req.Name + r.Type = RecordType(req.Type) + r.Content = req.Content + r.TTL = req.Ttl +} + +func (r *Record) Validate() error { + if r.Name == "" { + return errors.New("record name is required") + } + + if !util.IsValidDomain(r.Name) { + return errors.New("invalid record name format") + } + + if r.Type == "" { + return errors.New("record type is required") + } + + switch r.Type { + case RecordTypeA: + if err := validateIPv4(r.Content); err != nil { + return err + } + case RecordTypeAAAA: + if err := validateIPv6(r.Content); err != nil { + return err + } + case RecordTypeCNAME: + if !util.IsValidDomain(r.Content) { + return errors.New("invalid CNAME record format") + } + default: + return errors.New("invalid record type, must be A, AAAA, or CNAME") + } + + if r.TTL < 0 { + return errors.New("TTL cannot be negative") + } + + return nil +} + +func (r *Record) EventMeta(zoneID, zoneName string) map[string]any { + return map[string]any{ + "name": r.Name, + "type": string(r.Type), + "content": r.Content, + "ttl": r.TTL, + "zone_id": zoneID, + "zone_name": zoneName, + } +} + +func validateIPv4(content string) error { + if content == "" { + return errors.New("A record is required") + } + ip := net.ParseIP(content) + if ip == nil || ip.To4() == nil { + return errors.New("A record must be a valid IPv4 address") + } + return nil +} + +func validateIPv6(content string) error { + if content == "" { + return errors.New("AAAA record is required") + } + ip := net.ParseIP(content) + if ip == nil || ip.To4() != nil { + return errors.New("AAAA record must be a valid IPv6 address") + } + return nil +} diff --git a/management/internals/modules/zones/zone.go b/management/internals/modules/zones/zone.go new file mode 100644 index 00000000000..27adac1ac3b --- /dev/null +++ b/management/internals/modules/zones/zone.go @@ -0,0 +1,89 @@ +package zones + +import ( + "errors" + + "github.com/rs/xid" + + "github.com/netbirdio/netbird/management/internals/modules/zones/records" + "github.com/netbirdio/netbird/management/server/util" + "github.com/netbirdio/netbird/shared/management/http/api" +) + +type Zone struct { + ID string `gorm:"primaryKey"` + AccountID string `gorm:"index"` + Name string + Domain string + Enabled bool + EnableSearchDomain bool + DistributionGroups []string `gorm:"serializer:json"` + Records []*records.Record `gorm:"foreignKey:ZoneID;references:ID"` +} + +func NewZone(accountID, name, domain string, enabled, enableSearchDomain bool, distributionGroups []string) *Zone { + return &Zone{ + ID: xid.New().String(), + AccountID: accountID, + Name: name, + Domain: domain, + Enabled: enabled, + EnableSearchDomain: enableSearchDomain, + DistributionGroups: distributionGroups, + } +} + +func (z *Zone) ToAPIResponse() *api.Zone { + apiRecords := make([]api.DNSRecord, 0, len(z.Records)) + for _, record := range z.Records { + if apiRecord := record.ToAPIResponse(); apiRecord != nil { + apiRecords = append(apiRecords, *apiRecord) + } + } + + return &api.Zone{ + DistributionGroups: z.DistributionGroups, + Domain: z.Domain, + EnableSearchDomain: z.EnableSearchDomain, + Enabled: z.Enabled, + Id: z.ID, + Name: z.Name, + Records: apiRecords, + } +} + +func (z *Zone) FromAPIRequest(req *api.ZoneRequest) { + z.Name = req.Name + z.Domain = req.Domain + z.EnableSearchDomain = req.EnableSearchDomain + z.DistributionGroups = req.DistributionGroups + + enabled := true + if req.Enabled != nil { + enabled = *req.Enabled + } + z.Enabled = enabled +} + +func (z *Zone) Validate() error { + if z.Name == "" { + return errors.New("zone name is required") + } + if len(z.Name) > 255 { + return errors.New("zone name exceeds maximum length of 255 characters") + } + + if !util.IsValidDomain(z.Domain) { + return errors.New("invalid zone domain format") + } + + if len(z.DistributionGroups) == 0 { + return errors.New("at least one distribution group is required") + } + + return nil +} + +func (z *Zone) EventMeta() map[string]any { + return map[string]any{"name": z.Name, "domain": z.Domain} +} diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 37788e80eec..154ed2a10d2 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -93,7 +93,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.NetworkMapController()) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index af9ca5f2df8..764e4e7247f 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -8,6 +8,10 @@ import ( "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/modules/peers" + "github.com/netbirdio/netbird/management/internals/modules/zones" + zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" + recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/geolocation" @@ -128,3 +132,15 @@ func (s *BaseServer) NetworksManager() networks.Manager { return networks.NewManager(s.Store(), s.PermissionsManager(), s.ResourcesManager(), s.RoutesManager(), s.AccountManager()) }) } + +func (s *BaseServer) ZonesManager() zones.Manager { + return Create(s, func() zones.Manager { + return zonesManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager(), s.DNSDomain()) + }) +} + +func (s *BaseServer) RecordsManager() records.Manager { + return Create(s, func() records.Manager { + return recordsManager.NewManager(s.Store(), s.AccountManager(), s.PermissionsManager()) + }) +} diff --git a/management/internals/shared/grpc/conversion.go b/management/internals/shared/grpc/conversion.go index 2b15fe4b8ff..cdf6153706a 100644 --- a/management/internals/shared/grpc/conversion.go +++ b/management/internals/shared/grpc/conversion.go @@ -329,8 +329,9 @@ func shouldUsePortRange(rule *proto.FirewallRule) bool { // Helper function to convert nbdns.CustomZone to proto.CustomZone func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone { protoZone := &proto.CustomZone{ - Domain: zone.Domain, - Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + Domain: zone.Domain, + Records: make([]*proto.SimpleRecord, 0, len(zone.Records)), + SearchDomainDisabled: zone.SearchDomainDisabled, } for _, record := range zone.Records { protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{ diff --git a/management/server/account.go b/management/server/account.go index a9becc4b6af..b2d46f1cf0d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -295,7 +295,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return err } - if err = am.validateSettingsUpdate(ctx, newSettings, oldSettings, userID, accountID); err != nil { + if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { return err } @@ -385,7 +385,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return newSettings, nil } -func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, newSettings, oldSettings *types.Settings, userID, accountID string) error { +func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -399,6 +399,18 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, new return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) } + if newSettings.DNSDomain != oldSettings.DNSDomain && newSettings.DNSDomain != "" { + existingZone, err := transaction.GetZoneByDomain(ctx, accountID, newSettings.DNSDomain) + if err != nil { + if sErr, ok := status.FromError(err); !ok || sErr.Type() != status.NotFound { + return fmt.Errorf("failed to check existing zone: %w", err) + } + } + if existingZone != nil { + return status.Errorf(status.InvalidArgument, "peer DNS domain %s conflicts with existing custom DNS zone", newSettings.DNSDomain) + } + } + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, userID, accountID) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 7f125e3a0c8..b0e96578cde 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -27,6 +27,7 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" + "github.com/netbirdio/netbird/management/internals/modules/zones" "github.com/netbirdio/netbird/management/internals/server/config" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" @@ -397,7 +398,7 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) { } customZone := account.GetPeersCustomZone(context.Background(), "netbird.io") - networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, nil, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) assert.Len(t, networkMap.Peers, len(testCase.expectedPeers)) assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers)) } @@ -1676,7 +1677,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { }, } - routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}) + routes := account.GetRoutesToSync(context.Background(), "peer-2", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-3"}}, account.GetPeerGroups("peer-2")) assert.Len(t, routes, 2) routeIDs := make(map[route.ID]struct{}, 2) @@ -1686,7 +1687,7 @@ func TestAccount_GetRoutesToSync(t *testing.T) { assert.Contains(t, routeIDs, route.ID("route-2")) assert.Contains(t, routeIDs, route.ID("route-3")) - emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}) + emptyRoutes := account.GetRoutesToSync(context.Background(), "peer-3", []*nbpeer.Peer{{Key: "peer-1"}, {Key: "peer-2"}}, account.GetPeerGroups("peer-3")) assert.Len(t, emptyRoutes, 0) } @@ -2095,6 +2096,35 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerApproval(t *testing.T) } } +func TestDefaultAccountManager_UpdateAccountSettings_DNSDomainConflict(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") + require.NoError(t, err, "unable to create an account") + + ctx := context.Background() + err = manager.Store.CreateZone(ctx, &zones.Zone{ + ID: "test-zone-id", + AccountID: accountID, + Name: "Test Zone", + Domain: "custom.example.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{}, + }) + require.NoError(t, err, "unable to create custom DNS zone") + + _, err = manager.UpdateAccountSettings(ctx, accountID, userID, &types.Settings{ + DNSDomain: "custom.example.com", + PeerLoginExpiration: time.Hour, + PeerLoginExpirationEnabled: false, + Extra: &types.ExtraSettings{}, + }) + require.Error(t, err, "expecting to fail when DNS domain conflicts with custom zone") + assert.Contains(t, err.Error(), "conflicts with existing custom DNS zone") +} + func TestAccount_GetExpiredPeers(t *testing.T) { type test struct { name string diff --git a/management/server/activity/codes.go b/management/server/activity/codes.go index 2e3be1ef5a0..cd7c4b19f1f 100644 --- a/management/server/activity/codes.go +++ b/management/server/activity/codes.go @@ -181,6 +181,14 @@ const ( UserRejected Activity = 90 UserCreated Activity = 91 + DNSZoneCreated Activity = 92 + DNSZoneUpdated Activity = 93 + DNSZoneDeleted Activity = 94 + + DNSRecordCreated Activity = 95 + DNSRecordUpdated Activity = 96 + DNSRecordDeleted Activity = 97 + AccountDeleted Activity = 99999 ) @@ -290,6 +298,14 @@ var activityMap = map[Activity]Code{ UserApproved: {"User approved", "user.approve"}, UserRejected: {"User rejected", "user.reject"}, UserCreated: {"User created", "user.create"}, + + DNSZoneCreated: {"DNS zone created", "dns.zone.create"}, + DNSZoneUpdated: {"DNS zone updated", "dns.zone.update"}, + DNSZoneDeleted: {"DNS zone deleted", "dns.zone.delete"}, + + DNSRecordCreated: {"DNS zone record created", "dns.zone.record.create"}, + DNSRecordUpdated: {"DNS zone record updated", "dns.zone.record.update"}, + DNSRecordDeleted: {"DNS zone record deleted", "dns.zone.record.delete"}, } // StringCode returns a string code of the activity diff --git a/management/server/http/handler.go b/management/server/http/handler.go index b7c6c113c9c..d77e4155763 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -14,7 +14,10 @@ import ( "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/controllers/network_map" - + "github.com/netbirdio/netbird/management/internals/modules/zones" + zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" + recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/settings" @@ -66,6 +69,8 @@ func NewAPIHandler( permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, + zManager zones.Manager, + rManager records.Manager, networkMapController network_map.Controller, ) (http.Handler, error) { @@ -134,6 +139,8 @@ func NewAPIHandler( dns.AddEndpoints(accountManager, router) events.AddEndpoints(accountManager, router) networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router) + zonesManager.RegisterEndpoints(router, zManager) + recordsManager.RegisterEndpoints(router, rManager) return rootRouter, nil } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index f531f0cdb24..b4fe13e2d5d 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -10,6 +10,7 @@ import ( "github.com/gorilla/mux" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" @@ -297,9 +298,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } dnsDomain := h.networkMapController.GetDNSDomain(account.Settings) - - customZone := account.GetPeersCustomZone(r.Context(), dnsDomain) - netMap := account.GetPeerNetworkMap(r.Context(), peerID, customZone, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + netMap := account.GetPeerNetworkMap(r.Context(), peerID, dns.CustomZone{}, nil, validPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index e8513feb568..54dc0315f5c 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/management-integrations/integrations" + zonesManager "github.com/netbirdio/netbird/management/internals/modules/zones/manager" + recordsManager "github.com/netbirdio/netbird/management/internals/modules/zones/records/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/controllers/network_map" @@ -93,8 +95,10 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee routersManagerMock := routers.NewManagerMock() groupsManagerMock := groups.NewManagerMock() peersManager := peers.NewManager(store, permissionsManager) + customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") + zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, networkMapController) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 2b8981b97c1..4756f660e03 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -27,6 +27,8 @@ import ( "gorm.io/gorm/logger" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -112,6 +114,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, + &zones.Zone{}, &records.Record{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -4075,3 +4078,184 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro return peers, nil } + +func (s *SqlStore) CreateZone(ctx context.Context, zone *zones.Zone) error { + result := s.db.Create(zone) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create zone to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to create zone to store") + } + + return nil +} + +func (s *SqlStore) UpdateZone(ctx context.Context, zone *zones.Zone) error { + result := s.db.Select("*").Save(zone) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update zone to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to update zone to store") + } + + return nil +} + +func (s *SqlStore) DeleteZone(ctx context.Context, accountID, zoneID string) error { + result := s.db.Delete(&zones.Zone{}, accountAndIDQueryCondition, accountID, zoneID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete zone from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete zone from store") + } + + if result.RowsAffected == 0 { + return status.NewZoneNotFoundError(zoneID) + } + + return nil +} + +func (s *SqlStore) GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var zone *zones.Zone + result := tx.Preload("Records").Take(&zone, accountAndIDQueryCondition, accountID, zoneID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewZoneNotFoundError(zoneID) + } + + log.WithContext(ctx).Errorf("failed to get zone from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get zone from store") + } + + return zone, nil +} + +func (s *SqlStore) GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error) { + var zone *zones.Zone + result := s.db.Where("account_id = ? AND domain = ?", accountID, domain).First(&zone) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewZoneNotFoundError(domain) + } + + log.WithContext(ctx).Errorf("failed to get zone by domain from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get zone by domain from store") + } + + return zone, nil +} + +func (s *SqlStore) GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var zones []*zones.Zone + result := tx.Preload("Records").Find(&zones, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get zones from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get zones from store") + } + + return zones, nil +} + +func (s *SqlStore) CreateDNSRecord(ctx context.Context, record *records.Record) error { + result := s.db.Create(record) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to create dns record to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to create dns record to store") + } + + return nil +} + +func (s *SqlStore) UpdateDNSRecord(ctx context.Context, record *records.Record) error { + result := s.db.Select("*").Save(record) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update dns record to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to update dns record to store") + } + + return nil +} + +func (s *SqlStore) DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error { + result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete dns record from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete dns record from store") + } + + if result.RowsAffected == 0 { + return status.NewDNSRecordNotFoundError(recordID) + } + + return nil +} + +func (s *SqlStore) GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var record *records.Record + result := tx.Where("account_id = ? AND zone_id = ? AND id = ?", accountID, zoneID, recordID).Take(&record) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewDNSRecordNotFoundError(recordID) + } + + log.WithContext(ctx).Errorf("failed to get dns record from store: %v", result.Error) + return nil, status.Errorf(status.Internal, "failed to get dns record from store") + } + + return record, nil +} + +func (s *SqlStore) GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var recordsList []*records.Record + result := tx.Where("account_id = ? AND zone_id = ?", accountID, zoneID).Find(&recordsList) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get zone dns records from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get zone dns records from store") + } + + return recordsList, nil +} + +func (s *SqlStore) GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error) { + tx := s.db + if lockStrength != LockingStrengthNone { + tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) + } + + var recordsList []*records.Record + result := tx.Where("account_id = ? AND zone_id = ? AND name = ?", accountID, zoneID, name).Find(&recordsList) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get zone dns records by name from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get zone dns records by name from store") + } + + return recordsList, nil +} + +func (s *SqlStore) DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error { + result := s.db.Delete(&records.Record{}, "account_id = ? AND zone_id = ?", accountID, zoneID) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to delete zone dns records from store: %v", result.Error) + return status.Errorf(status.Internal, "failed to delete zone dns records from store") + } + + return nil +} diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 2e262391011..7d32bcdc053 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -22,6 +22,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -3794,3 +3796,476 @@ func TestSqlStore_ApproveAccountPeers(t *testing.T) { }) }) } + +func TestSqlStore_CreateZone(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID) + require.NoError(t, err) + require.NotNil(t, savedZone) + assert.Equal(t, zone.ID, savedZone.ID) + assert.Equal(t, zone.Name, savedZone.Name) + assert.Equal(t, zone.Domain, savedZone.Domain) + assert.Equal(t, zone.Enabled, savedZone.Enabled) + assert.Equal(t, zone.EnableSearchDomain, savedZone.EnableSearchDomain) + assert.Equal(t, zone.DistributionGroups, savedZone.DistributionGroups) +} + +func TestSqlStore_GetZoneByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + zoneID string + expectError bool + }{ + { + name: "retrieve existing zone", + accountID: accountID, + zoneID: zone.ID, + expectError: false, + }, + { + name: "retrieve non-existing zone", + accountID: accountID, + zoneID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty zone ID", + accountID: accountID, + zoneID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + savedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, savedZone) + } else { + require.NoError(t, err) + require.NotNil(t, savedZone) + assert.Equal(t, tt.zoneID, savedZone.ID) + } + }) + } +} + +func TestSqlStore_GetAccountZones(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone1 := zones.NewZone(accountID, "Zone 1", "example1.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone1) + require.NoError(t, err) + + zone2 := zones.NewZone(accountID, "Zone 2", "example2.com", true, true, []string{"group1", "group2"}) + err = store.CreateZone(context.Background(), zone2) + require.NoError(t, err) + + allZones, err := store.GetAccountZones(context.Background(), LockingStrengthNone, accountID) + require.NoError(t, err) + require.NotNil(t, allZones) + assert.GreaterOrEqual(t, len(allZones), 2) + + zoneIDs := make(map[string]bool) + for _, z := range allZones { + zoneIDs[z.ID] = true + } + assert.True(t, zoneIDs[zone1.ID]) + assert.True(t, zoneIDs[zone2.ID]) +} + +func TestSqlStore_GetZoneByDomain(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + otherAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3c" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + domain string + expectError bool + errorType status.Type + }{ + { + name: "retrieve existing zone by domain", + accountID: accountID, + domain: "example.com", + expectError: false, + }, + { + name: "retrieve non-existing zone domain", + accountID: accountID, + domain: "non-existing.com", + expectError: true, + errorType: status.NotFound, + }, + { + name: "retrieve with empty domain", + accountID: accountID, + domain: "", + expectError: true, + errorType: status.NotFound, + }, + { + name: "retrieve with different account ID", + accountID: otherAccountID, + domain: "example.com", + expectError: true, + errorType: status.NotFound, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + savedZone, err := store.GetZoneByDomain(context.Background(), tt.accountID, tt.domain) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, tt.errorType, sErr.Type()) + require.Nil(t, savedZone) + } else { + require.NoError(t, err) + require.NotNil(t, savedZone) + assert.Equal(t, tt.domain, savedZone.Domain) + assert.Equal(t, zone.ID, savedZone.ID) + assert.Equal(t, zone.Name, savedZone.Name) + } + }) + } +} + +func TestSqlStore_UpdateZone(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + zone.Name = "Updated Zone" + zone.Domain = "updated.com" + zone.Enabled = false + zone.EnableSearchDomain = true + zone.DistributionGroups = []string{"group2", "group3"} + + err = store.UpdateZone(context.Background(), zone) + require.NoError(t, err) + + updatedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID) + require.NoError(t, err) + require.NotNil(t, updatedZone) + assert.Equal(t, "Updated Zone", updatedZone.Name) + assert.Equal(t, "updated.com", updatedZone.Domain) + assert.False(t, updatedZone.Enabled) + assert.True(t, updatedZone.EnableSearchDomain) + assert.Equal(t, []string{"group2", "group3"}, updatedZone.DistributionGroups) +} + +func TestSqlStore_DeleteZone(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + err = store.DeleteZone(context.Background(), accountID, zone.ID) + require.NoError(t, err) + + deletedZone, err := store.GetZoneByID(context.Background(), LockingStrengthNone, accountID, zone.ID) + require.Error(t, err) + require.Nil(t, deletedZone) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) +} + +func TestSqlStore_CreateDNSRecord(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300) + + err = store.CreateDNSRecord(context.Background(), record) + require.NoError(t, err) + + savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID) + require.NoError(t, err) + require.NotNil(t, savedRecord) + assert.Equal(t, record.ID, savedRecord.ID) + assert.Equal(t, record.Name, savedRecord.Name) + assert.Equal(t, record.Type, savedRecord.Type) + assert.Equal(t, record.Content, savedRecord.Content) + assert.Equal(t, record.TTL, savedRecord.TTL) + assert.Equal(t, zone.ID, savedRecord.ZoneID) +} + +func TestSqlStore_GetDNSRecordByID(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300) + err = store.CreateDNSRecord(context.Background(), record) + require.NoError(t, err) + + tests := []struct { + name string + accountID string + zoneID string + recordID string + expectError bool + }{ + { + name: "retrieve existing record", + accountID: accountID, + zoneID: zone.ID, + recordID: record.ID, + expectError: false, + }, + { + name: "retrieve non-existing record", + accountID: accountID, + zoneID: zone.ID, + recordID: "non-existing", + expectError: true, + }, + { + name: "retrieve with empty record ID", + accountID: accountID, + zoneID: zone.ID, + recordID: "", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + savedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, tt.accountID, tt.zoneID, tt.recordID) + if tt.expectError { + require.Error(t, err) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) + require.Nil(t, savedRecord) + } else { + require.NoError(t, err) + require.NotNil(t, savedRecord) + assert.Equal(t, tt.recordID, savedRecord.ID) + } + }) + } +} + +func TestSqlStore_GetZoneDNSRecords(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + recordA := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300) + err = store.CreateDNSRecord(context.Background(), recordA) + require.NoError(t, err) + + recordAAAA := records.NewRecord(accountID, zone.ID, "ipv6.example.com", records.RecordTypeAAAA, "2001:db8::1", 300) + err = store.CreateDNSRecord(context.Background(), recordAAAA) + require.NoError(t, err) + + recordCNAME := records.NewRecord(accountID, zone.ID, "alias.example.com", records.RecordTypeCNAME, "www.example.com", 300) + err = store.CreateDNSRecord(context.Background(), recordCNAME) + require.NoError(t, err) + + allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID) + require.NoError(t, err) + require.NotNil(t, allRecords) + assert.Equal(t, 3, len(allRecords)) + + recordIDs := make(map[string]bool) + for _, r := range allRecords { + recordIDs[r.ID] = true + } + assert.True(t, recordIDs[recordA.ID]) + assert.True(t, recordIDs[recordAAAA.ID]) + assert.True(t, recordIDs[recordCNAME.ID]) +} + +func TestSqlStore_GetZoneDNSRecordsByName(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300) + err = store.CreateDNSRecord(context.Background(), record1) + require.NoError(t, err) + + record2 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeAAAA, "2001:db8::1", 300) + err = store.CreateDNSRecord(context.Background(), record2) + require.NoError(t, err) + + record3 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600) + err = store.CreateDNSRecord(context.Background(), record3) + require.NoError(t, err) + + recordsByName, err := store.GetZoneDNSRecordsByName(context.Background(), LockingStrengthNone, accountID, zone.ID, "www.example.com") + require.NoError(t, err) + require.NotNil(t, recordsByName) + assert.Equal(t, 2, len(recordsByName)) + + for _, r := range recordsByName { + assert.Equal(t, "www.example.com", r.Name) + } +} + +func TestSqlStore_UpdateDNSRecord(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300) + err = store.CreateDNSRecord(context.Background(), record) + require.NoError(t, err) + + record.Name = "api.example.com" + record.Content = "192.168.1.100" + record.TTL = 600 + + err = store.UpdateDNSRecord(context.Background(), record) + require.NoError(t, err) + + updatedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID) + require.NoError(t, err) + require.NotNil(t, updatedRecord) + assert.Equal(t, "api.example.com", updatedRecord.Name) + assert.Equal(t, "192.168.1.100", updatedRecord.Content) + assert.Equal(t, 600, updatedRecord.TTL) +} + +func TestSqlStore_DeleteDNSRecord(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + record := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300) + err = store.CreateDNSRecord(context.Background(), record) + require.NoError(t, err) + + err = store.DeleteDNSRecord(context.Background(), accountID, zone.ID, record.ID) + require.NoError(t, err) + + deletedRecord, err := store.GetDNSRecordByID(context.Background(), LockingStrengthNone, accountID, zone.ID, record.ID) + require.Error(t, err) + require.Nil(t, deletedRecord) + sErr, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, sErr.Type(), status.NotFound) +} + +func TestSqlStore_DeleteZoneDNSRecords(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + zone := zones.NewZone(accountID, "Test Zone", "example.com", true, false, []string{"group1"}) + err = store.CreateZone(context.Background(), zone) + require.NoError(t, err) + + record1 := records.NewRecord(accountID, zone.ID, "www.example.com", records.RecordTypeA, "192.168.1.1", 300) + err = store.CreateDNSRecord(context.Background(), record1) + require.NoError(t, err) + + record2 := records.NewRecord(accountID, zone.ID, "mail.example.com", records.RecordTypeA, "192.168.1.2", 600) + err = store.CreateDNSRecord(context.Background(), record2) + require.NoError(t, err) + + allRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID) + require.NoError(t, err) + assert.Equal(t, 2, len(allRecords)) + + err = store.DeleteZoneDNSRecords(context.Background(), accountID, zone.ID) + require.NoError(t, err) + + remainingRecords, err := store.GetZoneDNSRecords(context.Background(), LockingStrengthNone, accountID, zone.ID) + require.NoError(t, err) + assert.Equal(t, 0, len(remainingRecords)) +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 0ec7949f980..3904c26656a 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -23,6 +23,8 @@ import ( "gorm.io/gorm" "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/testutil" "github.com/netbirdio/netbird/management/server/types" @@ -204,6 +206,21 @@ type Store interface { MarkAccountPrimary(ctx context.Context, accountID string) error UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error GetPolicyRulesByResourceID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) ([]*types.PolicyRule, error) + + CreateZone(ctx context.Context, zone *zones.Zone) error + UpdateZone(ctx context.Context, zone *zones.Zone) error + DeleteZone(ctx context.Context, accountID, zoneID string) error + GetZoneByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) (*zones.Zone, error) + GetZoneByDomain(ctx context.Context, accountID, domain string) (*zones.Zone, error) + GetAccountZones(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*zones.Zone, error) + + CreateDNSRecord(ctx context.Context, record *records.Record) error + UpdateDNSRecord(ctx context.Context, record *records.Record) error + DeleteDNSRecord(ctx context.Context, accountID, zoneID, recordID string) error + GetDNSRecordByID(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, recordID string) (*records.Record, error) + GetZoneDNSRecords(ctx context.Context, lockStrength LockingStrength, accountID, zoneID string) ([]*records.Record, error) + GetZoneDNSRecordsByName(ctx context.Context, lockStrength LockingStrength, accountID, zoneID, name string) ([]*records.Record, error) + DeleteZoneDNSRecords(ctx context.Context, accountID, zoneID string) error } const ( diff --git a/management/server/types/account.go b/management/server/types/account.go index 9e86d89366c..bd0a6cad7f8 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -17,6 +17,8 @@ import ( log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -147,17 +149,16 @@ func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool { // GetRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. -func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer) []*route.Route { +func (a *Account) GetRoutesToSync(ctx context.Context, peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route { routes, peerDisabledRoutes := a.getRoutingPeerRoutes(ctx, peerID) peerRoutesMembership := make(LookupMap) for _, r := range append(routes, peerDisabledRoutes...) { peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} } - groupListMap := a.GetPeerGroups(peerID) for _, peer := range aclPeers { activeRoutes, _ := a.getRoutingPeerRoutes(ctx, peer.ID) - groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, groupListMap) + groupFilteredRoutes := a.filterRoutesByGroups(activeRoutes, peerGroups) filteredRoutes := a.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) routes = append(routes, filteredRoutes...) } @@ -271,6 +272,7 @@ func (a *Account) GetPeerNetworkMap( ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, validatedPeersMap map[string]struct{}, resourcePolicies map[string][]*Policy, routers map[string]map[string]*routerTypes.NetworkRouter, @@ -290,6 +292,8 @@ func (a *Account) GetPeerNetworkMap( } } + peerGroups := a.GetPeerGroups(peerID) + aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap) // exclude expired peers var peersToConnect []*nbpeer.Peer @@ -303,7 +307,7 @@ func (a *Account) GetPeerNetworkMap( peersToConnect = append(peersToConnect, p) } - routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect) + routesUpdate := a.GetRoutesToSync(ctx, peerID, peersToConnect, peerGroups) routesFirewallRules := a.GetPeerRoutesFirewallRules(ctx, peerID, validatedPeersMap) isRouter, networkResourcesRoutes, sourcePeers := a.GetNetworkResourcesRoutesToSync(ctx, peerID, resourcePolicies, routers) var networkResourcesFirewallRules []*RouteFirewallRule @@ -319,6 +323,7 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone + if peersCustomZone.Domain != "" { records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ @@ -326,6 +331,10 @@ func (a *Account) GetPeerNetworkMap( Records: records, }) } + + filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups) + zones = append(zones, filteredAccountZones...) + dnsUpdate.CustomZones = zones dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID) } @@ -1774,3 +1783,65 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p return filteredRecords } + +// filterPeerAppliedZones filters account zones based on the peer's group membership +func filterPeerAppliedZones(ctx context.Context, accountZones []*zones.Zone, peerGroups LookupMap) []nbdns.CustomZone { + var customZones []nbdns.CustomZone + + if len(peerGroups) == 0 { + return customZones + } + + for _, zone := range accountZones { + if !zone.Enabled || len(zone.Records) == 0 { + continue + } + + hasAccess := false + for _, distGroupID := range zone.DistributionGroups { + if _, found := peerGroups[distGroupID]; found { + hasAccess = true + break + } + } + + if !hasAccess { + continue + } + + simpleRecords := make([]nbdns.SimpleRecord, 0, len(zone.Records)) + for _, record := range zone.Records { + var recordType int + rData := record.Content + + switch record.Type { + case records.RecordTypeA: + recordType = int(dns.TypeA) + case records.RecordTypeAAAA: + recordType = int(dns.TypeAAAA) + case records.RecordTypeCNAME: + recordType = int(dns.TypeCNAME) + rData = dns.Fqdn(record.Content) + default: + log.WithContext(ctx).Warnf("unknown DNS record type %s for record %s", record.Type, record.ID) + continue + } + + simpleRecords = append(simpleRecords, nbdns.SimpleRecord{ + Name: dns.Fqdn(record.Name), + Type: recordType, + Class: nbdns.DefaultClass, + TTL: record.TTL, + RData: rData, + }) + } + + customZones = append(customZones, nbdns.CustomZone{ + Domain: dns.Fqdn(zone.Domain), + Records: simpleRecords, + SearchDomainDisabled: !zone.EnableSearchDomain, + }) + } + + return customZones +} diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index f9aa6a1c22a..6e3d4aa5dfe 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" + "github.com/netbirdio/netbird/management/internals/modules/zones/records" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -1238,3 +1240,515 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }) } } + +func Test_filterPeerAppliedZones(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + accountZones []*zones.Zone + peerGroups LookupMap + expected []nbdns.CustomZone + }{ + { + name: "empty peer groups returns empty custom zones", + accountZones: []*zones.Zone{}, + peerGroups: LookupMap{}, + expected: []nbdns.CustomZone{}, + }, + { + name: "peer has access to zone with A record", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "example.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "www.example.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{ + { + Domain: "example.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.example.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.1", + }, + }, + SearchDomainDisabled: true, + }, + }, + }, + { + name: "peer has access to zone with search domain enabled", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "internal.local", + Enabled: true, + EnableSearchDomain: true, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "api.internal.local", + Type: records.RecordTypeA, + Content: "10.0.0.1", + TTL: 600, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{ + { + Domain: "internal.local.", + Records: []nbdns.SimpleRecord{ + { + Name: "api.internal.local.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 600, + RData: "10.0.0.1", + }, + }, + SearchDomainDisabled: false, + }, + }, + }, + { + name: "peer has no access to zone", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "private.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group2"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "secret.private.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{}, + }, + { + name: "disabled zone is filtered out", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "disabled.com", + Enabled: false, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "www.disabled.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{}, + }, + { + name: "zone with no records is filtered out", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "empty.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{}, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{}, + }, + { + name: "peer has access via multiple groups", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "multi.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1", "group2", "group3"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "www.multi.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + }, + }, + }, + peerGroups: LookupMap{"group2": struct{}{}}, + expected: []nbdns.CustomZone{ + { + Domain: "multi.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.multi.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.1", + }, + }, + SearchDomainDisabled: true, + }, + }, + }, + { + name: "multiple zones with mixed access", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "allowed.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "www.allowed.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + }, + }, + { + ID: "zone2", + Domain: "denied.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group2"}, + Records: []*records.Record{ + { + ID: "record2", + Name: "www.denied.com", + Type: records.RecordTypeA, + Content: "192.168.1.2", + TTL: 300, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{ + { + Domain: "allowed.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.allowed.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.1", + }, + }, + SearchDomainDisabled: true, + }, + }, + }, + { + name: "zone with multiple record types", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "mixed.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "www.mixed.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + { + ID: "record2", + Name: "ipv6.mixed.com", + Type: records.RecordTypeAAAA, + Content: "2001:db8::1", + TTL: 600, + }, + { + ID: "record3", + Name: "alias.mixed.com", + Type: records.RecordTypeCNAME, + Content: "www.mixed.com", + TTL: 900, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{ + { + Domain: "mixed.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.mixed.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.1", + }, + { + Name: "ipv6.mixed.com.", + Type: int(dns.TypeAAAA), + Class: nbdns.DefaultClass, + TTL: 600, + RData: "2001:db8::1", + }, + { + Name: "alias.mixed.com.", + Type: int(dns.TypeCNAME), + Class: nbdns.DefaultClass, + TTL: 900, + RData: "www.mixed.com.", + }, + }, + SearchDomainDisabled: true, + }, + }, + }, + { + name: "multiple zones both accessible", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "first.com", + Enabled: true, + EnableSearchDomain: true, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "www.first.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + }, + }, + { + ID: "zone2", + Domain: "second.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record2", + Name: "www.second.com", + Type: records.RecordTypeA, + Content: "192.168.1.2", + TTL: 600, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{ + { + Domain: "first.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.first.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.1", + }, + }, + SearchDomainDisabled: false, + }, + { + Domain: "second.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.second.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 600, + RData: "192.168.1.2", + }, + }, + SearchDomainDisabled: true, + }, + }, + }, + { + name: "zone with multiple records of same type", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "multi-a.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "www.multi-a.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + { + ID: "record2", + Name: "www.multi-a.com", + Type: records.RecordTypeA, + Content: "192.168.1.2", + TTL: 300, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}}, + expected: []nbdns.CustomZone{ + { + Domain: "multi-a.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.multi-a.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.1", + }, + { + Name: "www.multi-a.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.2", + }, + }, + SearchDomainDisabled: true, + }, + }, + }, + { + name: "peer in multiple groups accessing different zones", + accountZones: []*zones.Zone{ + { + ID: "zone1", + Domain: "zone1.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + Records: []*records.Record{ + { + ID: "record1", + Name: "www.zone1.com", + Type: records.RecordTypeA, + Content: "192.168.1.1", + TTL: 300, + }, + }, + }, + { + ID: "zone2", + Domain: "zone2.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group2"}, + Records: []*records.Record{ + { + ID: "record2", + Name: "www.zone2.com", + Type: records.RecordTypeA, + Content: "192.168.1.2", + TTL: 300, + }, + }, + }, + }, + peerGroups: LookupMap{"group1": struct{}{}, "group2": struct{}{}}, + expected: []nbdns.CustomZone{ + { + Domain: "zone1.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.zone1.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.1", + }, + }, + SearchDomainDisabled: true, + }, + { + Domain: "zone2.com.", + Records: []nbdns.SimpleRecord{ + { + Name: "www.zone2.com.", + Type: int(dns.TypeA), + Class: nbdns.DefaultClass, + TTL: 300, + RData: "192.168.1.2", + }, + }, + SearchDomainDisabled: true, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterPeerAppliedZones(ctx, tt.accountZones, tt.peerGroups) + require.Equal(t, len(tt.expected), len(result), "number of custom zones should match") + + for i, expectedZone := range tt.expected { + assert.Equal(t, expectedZone.Domain, result[i].Domain, "domain should match") + assert.Equal(t, expectedZone.SearchDomainDisabled, result[i].SearchDomainDisabled, "search domain disabled flag should match") + assert.Equal(t, len(expectedZone.Records), len(result[i].Records), "number of records should match") + + for j, expectedRecord := range expectedZone.Records { + assert.Equal(t, expectedRecord.Name, result[i].Records[j].Name, "record name should match") + assert.Equal(t, expectedRecord.Type, result[i].Records[j].Type, "record type should match") + assert.Equal(t, expectedRecord.Class, result[i].Records[j].Class, "record class should match") + assert.Equal(t, expectedRecord.TTL, result[i].Records[j].TTL, "record TTL should match") + assert.Equal(t, expectedRecord.RData, result[i].Records[j].RData, "record RData should match") + } + } + }) + } +} diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go index c1099726fd4..bc75acb4b39 100644 --- a/management/server/types/networkmap.go +++ b/management/server/types/networkmap.go @@ -4,6 +4,7 @@ import ( "context" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/telemetry" ) @@ -25,11 +26,12 @@ func (a *Account) GetPeerNetworkMapExp( ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + accountZones []*zones.Zone, validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { a.initNetworkMapBuilder(validatedPeers) - return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics) + return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, accountZones, validatedPeers, metrics) } func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go index d85aaabb212..a89b47c3a13 100644 --- a/management/server/types/networkmap_golden_test.go +++ b/management/server/types/networkmap_golden_test.go @@ -69,7 +69,7 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil) normalizeAndSortNetworkMap(networkMap) @@ -105,7 +105,7 @@ func TestGetPeerNetworkMap_Golden_New(t *testing.T) { } builder := types.NewNetworkMapBuilder(account, validatedPeersMap) - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) normalizeAndSortNetworkMap(networkMap) @@ -141,7 +141,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) { b.Run("old builder", func(b *testing.B) { for range b.N { for _, peerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil) } } }) @@ -150,7 +150,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) { for range b.N { builder := types.NewNetworkMapBuilder(account, validatedPeersMap) for _, peerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil) + _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, nil, validatedPeersMap, nil) } } }) @@ -201,7 +201,7 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil) normalizeAndSortNetworkMap(networkMap) @@ -269,7 +269,7 @@ func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) { err := builder.OnPeerAddedIncremental(newPeerID) require.NoError(t, err, "error adding peer to cache") - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) normalizeAndSortNetworkMap(networkMap) @@ -320,7 +320,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { b.Run("old builder after add", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil) } } }) @@ -330,7 +330,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { for i := 0; i < b.N; i++ { _ = builder.OnPeerAddedIncremental(newPeerID) for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) } } }) @@ -395,7 +395,7 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil) normalizeAndSortNetworkMap(networkMap) @@ -476,7 +476,7 @@ func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) { err := builder.OnPeerAddedIncremental(newRouterID) require.NoError(t, err, "error adding router to cache") - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) normalizeAndSortNetworkMap(networkMap) @@ -550,7 +550,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { b.Run("old builder after add", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil) } } }) @@ -560,7 +560,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { for i := 0; i < b.N; i++ { _ = builder.OnPeerAddedIncremental(newRouterID) for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) } } }) @@ -604,7 +604,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil) normalizeAndSortNetworkMap(networkMap) @@ -665,7 +665,7 @@ func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) { err := builder.OnPeerDeleted(deletedPeerID) require.NoError(t, err, "error deleting peer from cache") - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) normalizeAndSortNetworkMap(networkMap) @@ -730,7 +730,7 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() - networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, resourcePolicies, routers, nil) normalizeAndSortNetworkMap(networkMap) @@ -797,7 +797,7 @@ func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) { err := builder.OnPeerDeleted(deletedRouterID) require.NoError(t, err, "error deleting routing peer from cache") - networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) normalizeAndSortNetworkMap(networkMap) @@ -847,7 +847,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { b.Run("old builder after delete", func(b *testing.B) { for i := 0; i < b.N; i++ { for _, testingPeerID := range peerIDs { - _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil, nil, nil) } } }) @@ -857,7 +857,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { for i := 0; i < b.N; i++ { _ = builder.OnPeerDeleted(deletedPeerID) for _, testingPeerID := range peerIDs { - _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, nil, validatedPeersMap, nil) } } }) diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go index 5790f16466a..765a51eb22f 100644 --- a/management/server/types/networkmapbuilder.go +++ b/management/server/types/networkmapbuilder.go @@ -14,6 +14,7 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/internals/modules/zones" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -936,7 +937,7 @@ func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) { } func (b *NetworkMapBuilder) GetPeerNetworkMap( - ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() @@ -958,7 +959,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap( return &NetworkMap{Network: account.Network.Copy()} } - nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) + nm := b.assembleNetworkMap(ctx, account, peer, aclView, routesView, dnsConfig, peersCustomZone, accountZones, validatedPeers) if metrics != nil { objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) @@ -975,8 +976,8 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap( } func (b *NetworkMapBuilder) assembleNetworkMap( - account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, - dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, + ctx context.Context, account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, + dnsConfig *nbdns.Config, peersCustomZone nbdns.CustomZone, accountZones []*zones.Zone, validatedPeers map[string]struct{}, ) *NetworkMap { var peersToConnect []*nbpeer.Peer @@ -1024,13 +1025,26 @@ func (b *NetworkMapBuilder) assembleNetworkMap( } finalDNSConfig := *dnsConfig - if finalDNSConfig.ServiceEnable && customZone.Domain != "" { + if finalDNSConfig.ServiceEnable { var zones []nbdns.CustomZone - records := filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers) - zones = append(zones, nbdns.CustomZone{ - Domain: customZone.Domain, - Records: records, - }) + + peerGroupsSlice := b.cache.peerToGroups[peer.ID] + peerGroups := make(LookupMap, len(peerGroupsSlice)) + for _, groupID := range peerGroupsSlice { + peerGroups[groupID] = struct{}{} + } + + if peersCustomZone.Domain != "" { + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect, expiredPeers) + zones = append(zones, nbdns.CustomZone{ + Domain: peersCustomZone.Domain, + Records: records, + }) + } + + filteredAccountZones := filterPeerAppliedZones(ctx, accountZones, peerGroups) + zones = append(zones, filteredAccountZones...) + finalDNSConfig.CustomZones = zones } diff --git a/management/server/util/util.go b/management/server/util/util.go index 617484274a1..eea6a72b025 100644 --- a/management/server/util/util.go +++ b/management/server/util/util.go @@ -1,5 +1,9 @@ package util +import "regexp" + +var domainRegex = regexp.MustCompile(`^(\*\.)?([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,}$`) + // Difference returns the elements in `a` that aren't in `b`. func Difference(a, b []string) []string { mb := make(map[string]struct{}, len(b)) @@ -50,3 +54,10 @@ func contains[T comparableObject[T]](slice []T, element T) bool { } return false } + +func IsValidDomain(domain string) bool { + if domain == "" { + return false + } + return domainRegex.MatchString(domain) +} diff --git a/shared/management/client/rest/client.go b/shared/management/client/rest/client.go index 2a5de5bbc49..c5a8b9f2536 100644 --- a/shared/management/client/rest/client.go +++ b/shared/management/client/rest/client.go @@ -58,9 +58,13 @@ type Client struct { Routes *RoutesAPI // DNS NetBird DNS APIs - // see more: https://docs.netbird.io/api/resources/routes + // see more: https://docs.netbird.io/api/resources/dns DNS *DNSAPI + // DNSZones NetBird DNS Zones APIs + // see more: https://docs.netbird.io/api/resources/dns-zones + DNSZones *DNSZonesAPI + // GeoLocation NetBird Geo Location APIs // see more: https://docs.netbird.io/api/resources/geo-locations GeoLocation *GeoLocationAPI @@ -112,6 +116,7 @@ func (c *Client) initialize() { c.Networks = &NetworksAPI{c} c.Routes = &RoutesAPI{c} c.DNS = &DNSAPI{c} + c.DNSZones = &DNSZonesAPI{c} c.GeoLocation = &GeoLocationAPI{c} c.Events = &EventsAPI{c} } diff --git a/shared/management/client/rest/dns_zones.go b/shared/management/client/rest/dns_zones.go new file mode 100644 index 00000000000..6ee7d336edf --- /dev/null +++ b/shared/management/client/rest/dns_zones.go @@ -0,0 +1,170 @@ +package rest + +import ( + "bytes" + "context" + "encoding/json" + + "github.com/netbirdio/netbird/shared/management/http/api" +) + +// DNSZonesAPI APIs for DNS Zones Management, do not use directly +type DNSZonesAPI struct { + c *Client +} + +// ListZones list all DNS zones +// See more: https://docs.netbird.io/api/resources/dns-zones#list-all-dns-zones +func (a *DNSZonesAPI) ListZones(ctx context.Context) ([]api.Zone, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.Zone](resp) + return ret, err +} + +// GetZone get DNS zone info +// See more: https://docs.netbird.io/api/resources/dns-zones#retrieve-a-dns-zone +func (a *DNSZonesAPI) GetZone(ctx context.Context, zoneID string) (*api.Zone, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.Zone](resp) + return &ret, err +} + +// CreateZone create new DNS zone +// See more: https://docs.netbird.io/api/resources/dns-zones#create-a-dns-zone +func (a *DNSZonesAPI) CreateZone(ctx context.Context, request api.PostApiDnsZonesJSONRequestBody) (*api.Zone, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/zones", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.Zone](resp) + return &ret, err +} + +// UpdateZone update DNS zone info +// See more: https://docs.netbird.io/api/resources/dns-zones#update-a-dns-zone +func (a *DNSZonesAPI) UpdateZone(ctx context.Context, zoneID string, request api.PutApiDnsZonesZoneIdJSONRequestBody) (*api.Zone, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/zones/"+zoneID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.Zone](resp) + return &ret, err +} + +// DeleteZone delete DNS zone +// See more: https://docs.netbird.io/api/resources/dns-zones#delete-a-dns-zone +func (a *DNSZonesAPI) DeleteZone(ctx context.Context, zoneID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/zones/"+zoneID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + + return nil +} + +// ListRecords list all DNS records in a zone +// See more: https://docs.netbird.io/api/resources/dns-zones#list-all-dns-records +func (a *DNSZonesAPI) ListRecords(ctx context.Context, zoneID string) ([]api.DNSRecord, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID+"/records", nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[[]api.DNSRecord](resp) + return ret, err +} + +// GetRecord get DNS record info +// See more: https://docs.netbird.io/api/resources/dns-zones#retrieve-a-dns-record +func (a *DNSZonesAPI) GetRecord(ctx context.Context, zoneID, recordID string) (*api.DNSRecord, error) { + resp, err := a.c.NewRequest(ctx, "GET", "/api/dns/zones/"+zoneID+"/records/"+recordID, nil, nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.DNSRecord](resp) + return &ret, err +} + +// CreateRecord create new DNS record in a zone +// See more: https://docs.netbird.io/api/resources/dns-zones#create-a-dns-record +func (a *DNSZonesAPI) CreateRecord(ctx context.Context, zoneID string, request api.PostApiDnsZonesZoneIdRecordsJSONRequestBody) (*api.DNSRecord, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "POST", "/api/dns/zones/"+zoneID+"/records", bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.DNSRecord](resp) + return &ret, err +} + +// UpdateRecord update DNS record info +// See more: https://docs.netbird.io/api/resources/dns-zones#update-a-dns-record +func (a *DNSZonesAPI) UpdateRecord(ctx context.Context, zoneID, recordID string, request api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody) (*api.DNSRecord, error) { + requestBytes, err := json.Marshal(request) + if err != nil { + return nil, err + } + resp, err := a.c.NewRequest(ctx, "PUT", "/api/dns/zones/"+zoneID+"/records/"+recordID, bytes.NewReader(requestBytes), nil) + if err != nil { + return nil, err + } + if resp.Body != nil { + defer resp.Body.Close() + } + ret, err := parseResponse[api.DNSRecord](resp) + return &ret, err +} + +// DeleteRecord delete DNS record +// See more: https://docs.netbird.io/api/resources/dns-zones#delete-a-dns-record +func (a *DNSZonesAPI) DeleteRecord(ctx context.Context, zoneID, recordID string) error { + resp, err := a.c.NewRequest(ctx, "DELETE", "/api/dns/zones/"+zoneID+"/records/"+recordID, nil, nil) + if err != nil { + return err + } + if resp.Body != nil { + defer resp.Body.Close() + } + + return nil +} diff --git a/shared/management/client/rest/dns_zones_test.go b/shared/management/client/rest/dns_zones_test.go new file mode 100644 index 00000000000..c04a3ea5792 --- /dev/null +++ b/shared/management/client/rest/dns_zones_test.go @@ -0,0 +1,460 @@ +//go:build integration +// +build integration + +package rest_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/shared/management/client/rest" + "github.com/netbirdio/netbird/shared/management/http/api" + "github.com/netbirdio/netbird/shared/management/http/util" +) + +var ( + testZone = api.Zone{ + Id: "zone123", + Name: "test-zone", + Domain: "example.com", + Enabled: true, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + } + + testDNSRecord = api.DNSRecord{ + Id: "record123", + Name: "www", + Content: "192.168.1.1", + Type: api.DNSRecordTypeA, + Ttl: 300, + } +) + +func TestDNSZone_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.Zone{testZone}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.ListZones(context.Background()) + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testZone, ret[0]) + }) +} + +func TestDNSZone_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "No", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.ListZones(context.Background()) + assert.Error(t, err) + assert.Equal(t, "No", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestDNSZone_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testZone) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.GetZone(context.Background(), "zone123") + require.NoError(t, err) + assert.Equal(t, testZone, *ret) + }) +} + +func TestDNSZone_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.GetZone(context.Background(), "zone123") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestDNSZone_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.PostApiDnsZonesJSONRequestBody + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "test-zone", req.Name) + assert.Equal(t, "example.com", req.Domain) + retBytes, _ := json.Marshal(testZone) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + enabled := true + ret, err := c.DNSZones.CreateZone(context.Background(), api.PostApiDnsZonesJSONRequestBody{ + Name: "test-zone", + Domain: "example.com", + Enabled: &enabled, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + }) + require.NoError(t, err) + assert.Equal(t, testZone, *ret) + }) +} + +func TestDNSZone_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid request", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.CreateZone(context.Background(), api.PostApiDnsZonesJSONRequestBody{ + Name: "test-zone", + Domain: "example.com", + }) + assert.Error(t, err) + assert.Equal(t, "Invalid request", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestDNSZone_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.PutApiDnsZonesZoneIdJSONRequestBody + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "updated-zone", req.Name) + retBytes, _ := json.Marshal(testZone) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + enabled := true + ret, err := c.DNSZones.UpdateZone(context.Background(), "zone123", api.PutApiDnsZonesZoneIdJSONRequestBody{ + Name: "updated-zone", + Domain: "example.com", + Enabled: &enabled, + EnableSearchDomain: false, + DistributionGroups: []string{"group1"}, + }) + require.NoError(t, err) + assert.Equal(t, testZone, *ret) + }) +} + +func TestDNSZone_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid request", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.UpdateZone(context.Background(), "zone123", api.PutApiDnsZonesZoneIdJSONRequestBody{ + Name: "updated-zone", + Domain: "example.com", + }) + assert.Error(t, err) + assert.Equal(t, "Invalid request", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestDNSZone_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.DNSZones.DeleteZone(context.Background(), "zone123") + require.NoError(t, err) + }) +} + +func TestDNSZone_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.DNSZones.DeleteZone(context.Background(), "zone123") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestDNSRecord_List_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal([]api.DNSRecord{testDNSRecord}) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.ListRecords(context.Background(), "zone123") + require.NoError(t, err) + assert.Len(t, ret, 1) + assert.Equal(t, testDNSRecord, ret[0]) + }) +} + +func TestDNSRecord_List_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Zone not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.ListRecords(context.Background(), "zone123") + assert.Error(t, err) + assert.Equal(t, "Zone not found", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestDNSRecord_Get_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "GET", r.Method) + retBytes, _ := json.Marshal(testDNSRecord) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.GetRecord(context.Background(), "zone123", "record123") + require.NoError(t, err) + assert.Equal(t, testDNSRecord, *ret) + }) +} + +func TestDNSRecord_Get_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.GetRecord(context.Background(), "zone123", "record123") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + assert.Empty(t, ret) + }) +} + +func TestDNSRecord_Create_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "POST", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.PostApiDnsZonesZoneIdRecordsJSONRequestBody + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "www", req.Name) + assert.Equal(t, "192.168.1.1", req.Content) + assert.Equal(t, api.DNSRecordTypeA, req.Type) + retBytes, _ := json.Marshal(testDNSRecord) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.CreateRecord(context.Background(), "zone123", api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{ + Name: "www", + Content: "192.168.1.1", + Type: api.DNSRecordTypeA, + Ttl: 300, + }) + require.NoError(t, err) + assert.Equal(t, testDNSRecord, *ret) + }) +} + +func TestDNSRecord_Create_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid record", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.CreateRecord(context.Background(), "zone123", api.PostApiDnsZonesZoneIdRecordsJSONRequestBody{ + Name: "www", + Content: "192.168.1.1", + Type: api.DNSRecordTypeA, + Ttl: 300, + }) + assert.Error(t, err) + assert.Equal(t, "Invalid record", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestDNSRecord_Update_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + reqBytes, err := io.ReadAll(r.Body) + require.NoError(t, err) + var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody + err = json.Unmarshal(reqBytes, &req) + require.NoError(t, err) + assert.Equal(t, "api", req.Name) + assert.Equal(t, "192.168.1.2", req.Content) + retBytes, _ := json.Marshal(testDNSRecord) + _, err = w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.UpdateRecord(context.Background(), "zone123", "record123", api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{ + Name: "api", + Content: "192.168.1.2", + Type: api.DNSRecordTypeA, + Ttl: 300, + }) + require.NoError(t, err) + assert.Equal(t, testDNSRecord, *ret) + }) +} + +func TestDNSRecord_Update_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Invalid record", Code: 400}) + w.WriteHeader(400) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + ret, err := c.DNSZones.UpdateRecord(context.Background(), "zone123", "record123", api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody{ + Name: "api", + Content: "192.168.1.2", + Type: api.DNSRecordTypeA, + Ttl: 300, + }) + assert.Error(t, err) + assert.Equal(t, "Invalid record", err.Error()) + assert.Nil(t, ret) + }) +} + +func TestDNSRecord_Delete_200(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(200) + }) + err := c.DNSZones.DeleteRecord(context.Background(), "zone123", "record123") + require.NoError(t, err) + }) +} + +func TestDNSRecord_Delete_Err(t *testing.T) { + withMockClient(func(c *rest.Client, mux *http.ServeMux) { + mux.HandleFunc("/api/dns/zones/zone123/records/record123", func(w http.ResponseWriter, r *http.Request) { + retBytes, _ := json.Marshal(util.ErrorResponse{Message: "Not found", Code: 404}) + w.WriteHeader(404) + _, err := w.Write(retBytes) + require.NoError(t, err) + }) + err := c.DNSZones.DeleteRecord(context.Background(), "zone123", "record123") + assert.Error(t, err) + assert.Equal(t, "Not found", err.Error()) + }) +} + +func TestDNSZones_Integration(t *testing.T) { + enabled := true + zoneReq := api.ZoneRequest{ + Name: "test-zone", + Domain: "test.example.com", + Enabled: &enabled, + EnableSearchDomain: false, + DistributionGroups: []string{"cs1tnh0hhcjnqoiuebeg"}, + } + + recordReq := api.DNSRecordRequest{ + Name: "api.test.example.com", + Content: "192.168.1.100", + Type: api.DNSRecordTypeA, + Ttl: 300, + } + + withBlackBoxServer(t, func(c *rest.Client) { + zone, err := c.DNSZones.CreateZone(context.Background(), zoneReq) + require.NoError(t, err) + assert.Equal(t, "test-zone", zone.Name) + assert.Equal(t, "test.example.com", zone.Domain) + + zones, err := c.DNSZones.ListZones(context.Background()) + require.NoError(t, err) + assert.Equal(t, *zone, zones[0]) + + getZone, err := c.DNSZones.GetZone(context.Background(), zone.Id) + require.NoError(t, err) + assert.Equal(t, *zone, *getZone) + + zoneReq.Name = "updated-zone" + updatedZone, err := c.DNSZones.UpdateZone(context.Background(), zone.Id, zoneReq) + require.NoError(t, err) + assert.Equal(t, "updated-zone", updatedZone.Name) + + record, err := c.DNSZones.CreateRecord(context.Background(), zone.Id, recordReq) + require.NoError(t, err) + assert.Equal(t, "api.test.example.com", record.Name) + assert.Equal(t, "192.168.1.100", record.Content) + + records, err := c.DNSZones.ListRecords(context.Background(), zone.Id) + require.NoError(t, err) + assert.Equal(t, *record, records[0]) + + getRecord, err := c.DNSZones.GetRecord(context.Background(), zone.Id, record.Id) + require.NoError(t, err) + assert.Equal(t, *record, *getRecord) + + recordReq.Name = "www.test.example.com" + updatedRecord, err := c.DNSZones.UpdateRecord(context.Background(), zone.Id, record.Id, recordReq) + require.NoError(t, err) + assert.Equal(t, "www.test.example.com", updatedRecord.Name) + + err = c.DNSZones.DeleteRecord(context.Background(), zone.Id, record.Id) + require.NoError(t, err) + + records, err = c.DNSZones.ListRecords(context.Background(), zone.Id) + require.NoError(t, err) + assert.Len(t, records, 0) + + err = c.DNSZones.DeleteZone(context.Background(), zone.Id) + require.NoError(t, err) + + zones, err = c.DNSZones.ListZones(context.Background()) + require.NoError(t, err) + assert.Len(t, zones, 0) + }) +} diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 2d063a7b5bd..603a94a88a2 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -25,6 +25,8 @@ tags: description: Interact with and view information about routes. - name: DNS description: Interact with and view information about DNS configuration. + - name: DNS Zones + description: Interact with and view information about custom DNS zones. - name: Events description: View information about the account and network events. - name: Accounts @@ -1705,6 +1707,100 @@ components: example: ch8i4ug6lnn4g9hqv7m0 required: - disabled_management_groups + ZoneRequest: + type: object + properties: + name: + description: Zone name identifier + type: string + maxLength: 255 + minLength: 1 + example: Office Zone + domain: + description: Zone domain (FQDN) + type: string + example: example.com + enabled: + description: Zone status + type: boolean + default: true + enable_search_domain: + description: Enable this zone as a search domain + type: boolean + example: false + distribution_groups: + description: Group IDs that defines groups of peers that will resolve this zone + type: array + items: + type: string + example: ch8i4ug6lnn4g9hqv7m0 + required: + - name + - domain + - enable_search_domain + - distribution_groups + Zone: + allOf: + - type: object + properties: + id: + description: Zone ID + type: string + example: ch8i4ug6lnn4g9hqv7m0 + records: + description: DNS records associated with this zone + type: array + items: + $ref: '#/components/schemas/DNSRecord' + required: + - id + - enabled + - records + - $ref: '#/components/schemas/ZoneRequest' + DNSRecordType: + type: string + description: DNS record type + enum: + - A + - AAAA + - CNAME + example: A + DNSRecordRequest: + type: object + properties: + name: + description: FQDN for the DNS record. Must be a subdomain within or match the zone's domain. + type: string + example: www.example.com + type: + $ref: '#/components/schemas/DNSRecordType' + content: + description: DNS record content (IP address for A/AAAA, domain for CNAME) + type: string + maxLength: 255 + minLength: 1 + example: 192.168.1.1 + ttl: + description: Time to live in seconds + type: integer + minimum: 0 + example: 300 + required: + - name + - type + - content + - ttl + DNSRecord: + allOf: + - type: object + properties: + id: + description: DNS record ID + type: string + example: ch8i4ug6lnn4g9hqv7m0 + required: + - id + - $ref: '#/components/schemas/DNSRecordRequest' Event: type: object properties: @@ -4505,6 +4601,347 @@ paths: "$ref": "#/components/responses/forbidden" '500': "$ref": "#/components/responses/internal_error" + /api/dns/zones: + get: + summary: List all DNS Zones + description: Returns a list of all custom DNS zones + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + responses: + '200': + description: A JSON Array of DNS Zones + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/Zone' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a DNS Zone + description: Creates a new custom DNS zone + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + requestBody: + description: A DNS zone object + content: + 'application/json': + schema: + $ref: '#/components/schemas/ZoneRequest' + responses: + '200': + description: A JSON Object of the created DNS Zone + content: + application/json: + schema: + $ref: '#/components/schemas/Zone' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '500': + "$ref": "#/components/responses/internal_error" + /api/dns/zones/{zoneId}: + get: + summary: Retrieve a DNS Zone + description: Returns information about a specific DNS zone + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: zoneId + required: true + schema: + type: string + description: The unique identifier of a zone + example: chacbco6lnnbn6cg5s91 + responses: + '200': + description: A JSON Object of a DNS Zone + content: + application/json: + schema: + $ref: '#/components/schemas/Zone' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a DNS Zone + description: Updates a custom DNS zone + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: zoneId + required: true + schema: + type: string + description: The unique identifier of a zone + example: chacbco6lnnbn6cg5s91 + requestBody: + description: A DNS zone object + content: + 'application/json': + schema: + $ref: '#/components/schemas/ZoneRequest' + responses: + '200': + description: A JSON Object of the updated DNS Zone + content: + application/json: + schema: + $ref: '#/components/schemas/Zone' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a DNS Zone + description: Deletes a custom DNS zone + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: zoneId + required: true + schema: + type: string + description: The unique identifier of a zone + example: chacbco6lnnbn6cg5s91 + responses: + '200': + description: Zone deletion successful + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/dns/zones/{zoneId}/records: + get: + summary: List all DNS Records + description: Returns a list of all DNS records in a zone + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: zoneId + required: true + schema: + type: string + description: The unique identifier of a zone + example: chacbco6lnnbn6cg5s91 + responses: + '200': + description: A JSON Array of DNS Records + content: + application/json: + schema: + type: array + items: + $ref: '#/components/schemas/DNSRecord' + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + post: + summary: Create a DNS Record + description: Creates a new DNS record in a zone + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: zoneId + required: true + schema: + type: string + description: The unique identifier of a zone + example: chacbco6lnnbn6cg5s91 + requestBody: + description: A DNS record object + content: + 'application/json': + schema: + $ref: '#/components/schemas/DNSRecordRequest' + responses: + '200': + description: A JSON Object of the created DNS Record + content: + application/json: + schema: + $ref: '#/components/schemas/DNSRecord' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + /api/dns/zones/{zoneId}/records/{recordId}: + get: + summary: Retrieve a DNS Record + description: Returns information about a specific DNS record + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: zoneId + required: true + schema: + type: string + description: The unique identifier of a zone + example: chacbco6lnnbn6cg5s91 + - in: path + name: recordId + required: true + schema: + type: string + description: The unique identifier of a DNS record + example: chacbco6lnnbn6cg5s92 + responses: + '200': + description: A JSON Object of a DNS Record + content: + application/json: + schema: + $ref: '#/components/schemas/DNSRecord' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + put: + summary: Update a DNS Record + description: Updates a DNS record in a zone + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: zoneId + required: true + schema: + type: string + description: The unique identifier of a zone + example: chacbco6lnnbn6cg5s91 + - in: path + name: recordId + required: true + schema: + type: string + description: The unique identifier of a DNS record + example: chacbco6lnnbn6cg5s92 + requestBody: + description: A DNS record object + content: + 'application/json': + schema: + $ref: '#/components/schemas/DNSRecordRequest' + responses: + '200': + description: A JSON Object of the updated DNS Record + content: + application/json: + schema: + $ref: '#/components/schemas/DNSRecord' + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" + delete: + summary: Delete a DNS Record + description: Deletes a DNS record from a zone + tags: [ DNS Zones ] + security: + - BearerAuth: [ ] + - TokenAuth: [ ] + parameters: + - in: path + name: zoneId + required: true + schema: + type: string + description: The unique identifier of a zone + example: chacbco6lnnbn6cg5s91 + - in: path + name: recordId + required: true + schema: + type: string + description: The unique identifier of a DNS record + example: chacbco6lnnbn6cg5s92 + responses: + '200': + description: Record deletion successful + '400': + "$ref": "#/components/responses/bad_request" + '401': + "$ref": "#/components/responses/requires_authentication" + '403': + "$ref": "#/components/responses/forbidden" + '404': + "$ref": "#/components/responses/not_found" + '500': + "$ref": "#/components/responses/internal_error" /api/events/audit: get: summary: List all Audit Events diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index d3e4255483c..13116472318 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -12,6 +12,13 @@ const ( TokenAuthScopes = "TokenAuth.Scopes" ) +// Defines values for DNSRecordType. +const ( + DNSRecordTypeA DNSRecordType = "A" + DNSRecordTypeAAAA DNSRecordType = "AAAA" + DNSRecordTypeCNAME DNSRecordType = "CNAME" +) + // Defines values for EventActivityCode. const ( EventActivityCodeAccountCreate EventActivityCode = "account.create" @@ -407,6 +414,42 @@ type CreateSetupKeyRequest struct { UsageLimit int `json:"usage_limit"` } +// DNSRecord defines model for DNSRecord. +type DNSRecord struct { + // Content DNS record content (IP address for A/AAAA, domain for CNAME) + Content string `json:"content"` + + // Id DNS record ID + Id string `json:"id"` + + // Name FQDN for the DNS record. Must be a subdomain within or match the zone's domain. + Name string `json:"name"` + + // Ttl Time to live in seconds + Ttl int `json:"ttl"` + + // Type DNS record type + Type DNSRecordType `json:"type"` +} + +// DNSRecordRequest defines model for DNSRecordRequest. +type DNSRecordRequest struct { + // Content DNS record content (IP address for A/AAAA, domain for CNAME) + Content string `json:"content"` + + // Name FQDN for the DNS record. Must be a subdomain within or match the zone's domain. + Name string `json:"name"` + + // Ttl Time to live in seconds + Ttl int `json:"ttl"` + + // Type DNS record type + Type DNSRecordType `json:"type"` +} + +// DNSRecordType DNS record type +type DNSRecordType string + // DNSSettings defines model for DNSSettings. type DNSSettings struct { // DisabledManagementGroups Groups whose DNS management is disabled @@ -1863,6 +1906,48 @@ type UserRequest struct { Role string `json:"role"` } +// Zone defines model for Zone. +type Zone struct { + // DistributionGroups Group IDs that defines groups of peers that will resolve this zone + DistributionGroups []string `json:"distribution_groups"` + + // Domain Zone domain (FQDN) + Domain string `json:"domain"` + + // EnableSearchDomain Enable this zone as a search domain + EnableSearchDomain bool `json:"enable_search_domain"` + + // Enabled Zone status + Enabled bool `json:"enabled"` + + // Id Zone ID + Id string `json:"id"` + + // Name Zone name identifier + Name string `json:"name"` + + // Records DNS records associated with this zone + Records []DNSRecord `json:"records"` +} + +// ZoneRequest defines model for ZoneRequest. +type ZoneRequest struct { + // DistributionGroups Group IDs that defines groups of peers that will resolve this zone + DistributionGroups []string `json:"distribution_groups"` + + // Domain Zone domain (FQDN) + Domain string `json:"domain"` + + // EnableSearchDomain Enable this zone as a search domain + EnableSearchDomain bool `json:"enable_search_domain"` + + // Enabled Zone status + Enabled *bool `json:"enabled,omitempty"` + + // Name Zone name identifier + Name string `json:"name"` +} + // GetApiEventsNetworkTrafficParams defines parameters for GetApiEventsNetworkTraffic. type GetApiEventsNetworkTrafficParams struct { // Page Page number @@ -1947,6 +2032,18 @@ type PutApiDnsNameserversNsgroupIdJSONRequestBody = NameserverGroupRequest // PutApiDnsSettingsJSONRequestBody defines body for PutApiDnsSettings for application/json ContentType. type PutApiDnsSettingsJSONRequestBody = DNSSettings +// PostApiDnsZonesJSONRequestBody defines body for PostApiDnsZones for application/json ContentType. +type PostApiDnsZonesJSONRequestBody = ZoneRequest + +// PutApiDnsZonesZoneIdJSONRequestBody defines body for PutApiDnsZonesZoneId for application/json ContentType. +type PutApiDnsZonesZoneIdJSONRequestBody = ZoneRequest + +// PostApiDnsZonesZoneIdRecordsJSONRequestBody defines body for PostApiDnsZonesZoneIdRecords for application/json ContentType. +type PostApiDnsZonesZoneIdRecordsJSONRequestBody = DNSRecordRequest + +// PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody defines body for PutApiDnsZonesZoneIdRecordsRecordId for application/json ContentType. +type PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody = DNSRecordRequest + // PostApiGroupsJSONRequestBody defines body for PostApiGroups for application/json ContentType. type PostApiGroupsJSONRequestBody = GroupRequest diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 09676847e7c..ea02173e900 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -252,3 +252,13 @@ func NewOperationNotFoundError(operation operations.Operation) error { func NewRouteNotFoundError(routeID string) error { return Errorf(NotFound, "route: %s not found", routeID) } + +// NewZoneNotFoundError creates a new Error with NotFound type for a missing dns zone. +func NewZoneNotFoundError(zoneID string) error { + return Errorf(NotFound, "zone: %s not found", zoneID) +} + +// NewDNSRecordNotFoundError creates a new Error with NotFound type for a missing dns record. +func NewDNSRecordNotFoundError(recordID string) error { + return Errorf(NotFound, "dns record: %s not found", recordID) +}