diff --git a/pkg/node/node_test.go b/pkg/node/node_test.go index 31048f11..f7af0310 100644 --- a/pkg/node/node_test.go +++ b/pkg/node/node_test.go @@ -23,6 +23,7 @@ import ( . "github.com/onsi/gomega" "github.com/mudler/edgevpn/pkg/blockchain" + "github.com/mudler/edgevpn/pkg/discovery" "github.com/mudler/edgevpn/pkg/logger" . "github.com/mudler/edgevpn/pkg/node" ) @@ -33,11 +34,17 @@ var _ = Describe("Node", func() { l := Logger(logger.New(log.LevelFatal)) - Context("Configuration", func() { + Context("Node configuration validation", func() { It("fails if is not valid", func() { - _, err := New(FromBase64(true, true, " ", nil, nil), WithStore(&blockchain.MemoryStore{}), l) + _, err := New(FromBase64(true, true, " ", &discovery.DHT{}, &discovery.MDNS{}), WithStore(&blockchain.MemoryStore{}), l) Expect(err).To(HaveOccurred()) + _, err = New(FromBase64(true, true, token, nil, nil), WithStore(&blockchain.MemoryStore{}), l) + Expect(err).To(HaveOccurred()) + }) + + It("passes if when valid", func() { + _, err := New(FromBase64(true, true, token, &discovery.DHT{}, &discovery.MDNS{}), WithStore(&blockchain.MemoryStore{}), l) Expect(err).ToNot(HaveOccurred()) }) }) diff --git a/pkg/node/options.go b/pkg/node/options.go index 44ae8521..9d75fb7e 100644 --- a/pkg/node/options.go +++ b/pkg/node/options.go @@ -273,12 +273,12 @@ func (y YAMLConnectionConfig) YAML() string { return string(bytesData) } -func (y YAMLConnectionConfig) copy(mdns, dht bool, cfg *Config, d *discovery.DHT, m *discovery.MDNS) { +func (y YAMLConnectionConfig) copy(mdns, dht bool, cfg *Config, d *discovery.DHT, m *discovery.MDNS) error { if d == nil { - d = discovery.NewDHT() + return errors.New("DHT is nil") } if m == nil { - m = &discovery.MDNS{} + return errors.New("MDNS is nil") } d.RefreshDiscoveryTime = cfg.DiscoveryInterval @@ -301,6 +301,8 @@ func (y YAMLConnectionConfig) copy(mdns, dht bool, cfg *Config, d *discovery.DHT } cfg.SealKeyLength = y.OTP.Crypto.Length cfg.MaxMessageSize = y.MaxMessageSize + + return nil } const defaultKeyLength = 43 @@ -357,8 +359,7 @@ func FromYaml(enablemDNS, enableDHT bool, path string, d *discovery.DHT, m *disc return errors.Wrap(err, "parsing yaml") } - t.copy(enablemDNS, enableDHT, cfg, d, m) - return nil + return t.copy(enablemDNS, enableDHT, cfg, d, m) } } @@ -376,7 +377,6 @@ func FromBase64(enablemDNS, enableDHT bool, bb string, d *discovery.DHT, m *disc if err := yaml.Unmarshal(configDec, &t); err != nil { return errors.Wrap(err, "parsing yaml") } - t.copy(enablemDNS, enableDHT, cfg, d, m) - return nil + return t.copy(enablemDNS, enableDHT, cfg, d, m) } }