diff --git a/client/internal/dns/host.go b/client/internal/dns/host.go index f7dc46a6ba6..4fbe3fc736c 100644 --- a/client/internal/dns/host.go +++ b/client/internal/dns/host.go @@ -20,7 +20,7 @@ type hostManager interface { type SystemDNSSettings struct { Domains []string - ServerIP netip.Addr + ServerIPs []netip.Addr ServerPort int } diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 71badf0d4ba..1e4159eb4db 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -79,10 +79,11 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, ".")) } + configServerIPs := []netip.Addr{config.ServerIP} matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) var err error if len(matchDomains) != 0 { - err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) + err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), configServerIPs, config.ServerPort) } else { log.Infof("removing match domains from the system") err = s.removeKeyFromSystemConfig(matchKey) @@ -94,7 +95,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { - err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort) + err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), configServerIPs, config.ServerPort) } else { log.Infof("removing search domains from the system") err = s.removeKeyFromSystemConfig(searchKey) @@ -168,20 +169,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { } func (s *systemConfigurator) addLocalDNS() error { - if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { + if len(s.systemDNSSettings.ServerIPs) == 0 || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { + if len(s.systemDNSSettings.ServerIPs) == 0 || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") return nil } if err := s.addSearchDomains( localKey, - strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIPs, s.systemDNSSettings.ServerPort, ); err != nil { return fmt.Errorf("add search domains: %w", err) } @@ -190,7 +191,7 @@ func (s *systemConfigurator) addLocalDNS() error { } func (s *systemConfigurator) recordSystemDNSSettings(force bool) error { - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 && !force { + if len(s.systemDNSSettings.ServerIPs) != 0 && len(s.systemDNSSettings.Domains) != 0 && !force { return nil } @@ -203,31 +204,22 @@ func (s *systemConfigurator) recordSystemDNSSettings(force bool) error { return nil } -func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { - primaryServiceKey, _, err := s.getPrimaryService() - if err != nil || primaryServiceKey == "" { - return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err) - } - dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey) - line := buildCommandLine("show", dnsServiceKey, "") - stdinCommands := wrapCommand(line) - - b, err := runSystemConfigCommand(stdinCommands) - if err != nil { - return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err) - } - +func parseSystemConfigOutput(output []byte) (SystemDNSSettings, error) { var dnsSettings SystemDNSSettings + localDomainsMap := make(map[string]struct{}) inSearchDomainsArray := false inServerAddressesArray := false - scanner := bufio.NewScanner(bytes.NewReader(b)) + scanner := bufio.NewScanner(bytes.NewReader(output)) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) switch { case strings.HasPrefix(line, "DomainName :"): domainName := strings.TrimSpace(strings.Split(line, ":")[1]) - dnsSettings.Domains = append(dnsSettings.Domains, domainName) + if _, exists := localDomainsMap[domainName]; !exists { + localDomainsMap[domainName] = struct{}{} + dnsSettings.Domains = append(dnsSettings.Domains, domainName) + } case line == "SearchDomains : {": inSearchDomainsArray = true continue @@ -241,12 +233,16 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { if inSearchDomainsArray { searchDomain := strings.Split(line, " : ")[1] - dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) + if _, exists := localDomainsMap[searchDomain]; !exists { + localDomainsMap[searchDomain] = struct{}{} + dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) + } } else if inServerAddressesArray { address := strings.Split(line, " : ")[1] - if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { - dnsSettings.ServerIP = ip.Unmap() - inServerAddressesArray = false // Stop reading after finding the first IPv4 address + if ip, err := netip.ParseAddr(address); err == nil { + if ip.IsValid() { + dnsSettings.ServerIPs = append(dnsSettings.ServerIPs, ip.Unmap()) + } } } } @@ -255,14 +251,30 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { return dnsSettings, err } - // default to 53 port dnsSettings.ServerPort = DefaultPort return dnsSettings, nil } -func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { - err := s.addDNSState(key, domains, ip, port, true) +func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { + primaryServiceKey, _, err := s.getPrimaryService() + if err != nil || primaryServiceKey == "" { + return SystemDNSSettings{}, fmt.Errorf("couldn't find the primary service key: %w", err) + } + dnsServiceKey := getKeyWithInput(primaryServiceStateKeyFormat, primaryServiceKey) + line := buildCommandLine("show", dnsServiceKey, "") + stdinCommands := wrapCommand(line) + + b, err := runSystemConfigCommand(stdinCommands) + if err != nil { + return SystemDNSSettings{}, fmt.Errorf("sending the command: %w", err) + } + + return parseSystemConfigOutput(b) +} + +func (s *systemConfigurator) addSearchDomains(key, domains string, ips []netip.Addr, port int) error { + err := s.addDNSState(key, domains, ips, port, true) if err != nil { return fmt.Errorf("add dns state: %w", err) } @@ -274,8 +286,8 @@ func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr return nil } -func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error { - err := s.addDNSState(key, domains, dnsServer, port, false) +func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServers []netip.Addr, port int) error { + err := s.addDNSState(key, domains, dnsServers, port, false) if err != nil { return fmt.Errorf("add dns state: %w", err) } @@ -287,14 +299,20 @@ func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer neti return nil } -func (s *systemConfigurator) addDNSState(state, domains string, dnsServer netip.Addr, port int, enableSearch bool) error { +func (s *systemConfigurator) addDNSState(state, domains string, dnsServers []netip.Addr, port int, enableSearch bool) error { noSearch := "1" if enableSearch { noSearch = "0" } + + servers := make([]string, 0, len(dnsServers)) + for _, serverIP := range dnsServers { + servers = append(servers, serverIP.String()) + } + serversStr := strings.Join(servers, " ") lines := buildAddCommandLine(keySupplementalMatchDomains, arraySymbol+domains) lines += buildAddCommandLine(keySupplementalMatchDomainsNoSearch, digitSymbol+noSearch) - lines += buildAddCommandLine(keyServerAddresses, arraySymbol+dnsServer.String()) + lines += buildAddCommandLine(keyServerAddresses, arraySymbol+serversStr) lines += buildAddCommandLine(keyServerPort, digitSymbol+strconv.Itoa(port)) addDomainCommand := buildCreateStateWithOperation(state, lines) diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go index c4efd17b06e..424e0ab0185 100644 --- a/client/internal/dns/host_darwin_test.go +++ b/client/internal/dns/host_darwin_test.go @@ -109,3 +109,147 @@ func removeTestDNSKey(key string) error { _, err := cmd.CombinedOutput() return err } + +func TestParseSystemConfigOutput_Complete(t *testing.T) { + mockOutput := ` { + DomainName : example.com + SearchDomains : { + 0 : internal.local + 1 : corp.example.com + } + ServerAddresses : { + 0 : 192.168.1.1 + 1 : 8.8.8.8 + } +}` + + result, err := parseSystemConfigOutput([]byte(mockOutput)) + require.NoError(t, err) + + assert.Equal(t, 53, result.ServerPort) + assert.Len(t, result.Domains, 3) + assert.Contains(t, result.Domains, "example.com") + assert.Contains(t, result.Domains, "internal.local") + assert.Contains(t, result.Domains, "corp.example.com") + + assert.Len(t, result.ServerIPs, 2) + assert.Equal(t, "192.168.1.1", result.ServerIPs[0].String()) + assert.Equal(t, "8.8.8.8", result.ServerIPs[1].String()) +} + +func TestParseSystemConfigOutput_MultipleServers(t *testing.T) { + mockOutput := ` { + DomainName : test.local + ServerAddresses : { + 0 : 192.168.1.1 + 1 : 10.0.0.1 + 2 : 2001:4860:4860::8888 + 3 : fd00::1 + } +}` + + result, err := parseSystemConfigOutput([]byte(mockOutput)) + require.NoError(t, err) + + assert.Len(t, result.ServerIPs, 4) + assert.Equal(t, "192.168.1.1", result.ServerIPs[0].String()) + assert.Equal(t, "10.0.0.1", result.ServerIPs[1].String()) + assert.Equal(t, "2001:4860:4860::8888", result.ServerIPs[2].String()) + assert.Equal(t, "fd00::1", result.ServerIPs[3].String()) +} + +func TestParseSystemConfigOutput_DomainDeduplication(t *testing.T) { + mockOutput := ` { + DomainName : example.com + SearchDomains : { + 0 : example.com + 1 : internal.local + 2 : example.com + } + ServerAddresses : { + 0 : 192.168.1.1 + } +}` + + result, err := parseSystemConfigOutput([]byte(mockOutput)) + require.NoError(t, err) + + assert.Len(t, result.Domains, 2) + assert.Contains(t, result.Domains, "example.com") + assert.Contains(t, result.Domains, "internal.local") + + domainCount := make(map[string]int) + for _, domain := range result.Domains { + domainCount[domain]++ + } + assert.Equal(t, 1, domainCount["example.com"]) +} + +func TestParseSystemConfigOutput_EmptyOutput(t *testing.T) { + result, err := parseSystemConfigOutput([]byte("")) + require.NoError(t, err) + + assert.Equal(t, 53, result.ServerPort) + assert.Empty(t, result.Domains) + assert.Empty(t, result.ServerIPs) +} + +func TestParseSystemConfigOutput_InvalidIP(t *testing.T) { + mockOutput := ` { + DomainName : test.local + ServerAddresses : { + 0 : 192.168.1.1 + 1 : invalid-ip + 2 : 8.8.8.8 + 3 : 999.999.999.999 + } +}` + + result, err := parseSystemConfigOutput([]byte(mockOutput)) + require.NoError(t, err) + + assert.Len(t, result.ServerIPs, 2) + assert.Equal(t, "192.168.1.1", result.ServerIPs[0].String()) + assert.Equal(t, "8.8.8.8", result.ServerIPs[1].String()) +} + +func TestParseSystemConfigOutput_OnlyDomainName(t *testing.T) { + mockOutput := ` { + DomainName : example.com +}` + + result, err := parseSystemConfigOutput([]byte(mockOutput)) + require.NoError(t, err) + + assert.Len(t, result.Domains, 1) + assert.Equal(t, "example.com", result.Domains[0]) + assert.Empty(t, result.ServerIPs) + assert.Equal(t, 53, result.ServerPort) +} + +func TestParseSystemConfigOutput_NestedArrays(t *testing.T) { + mockOutput := ` { + DomainName : example.com + SearchDomains : { + 0 : search1.local + 1 : search2.local + } + ServerAddresses : { + 0 : 192.168.1.1 + 1 : 192.168.1.2 + } + OtherField : value +}` + + result, err := parseSystemConfigOutput([]byte(mockOutput)) + require.NoError(t, err) + + assert.Len(t, result.Domains, 3) + assert.Contains(t, result.Domains, "example.com") + assert.Contains(t, result.Domains, "search1.local") + assert.Contains(t, result.Domains, "search2.local") + + assert.Len(t, result.ServerIPs, 2) + assert.Equal(t, "192.168.1.1", result.ServerIPs[0].String()) + assert.Equal(t, "192.168.1.2", result.ServerIPs[1].String()) +}