diff --git a/config/config.go b/config/config.go index 5817472c4..9865d06c5 100644 --- a/config/config.go +++ b/config/config.go @@ -78,10 +78,25 @@ var defaultStatusConfig = StatusConfig{ } type NatsConfig struct { - Host string `yaml:"host"` - Port uint16 `yaml:"port"` - User string `yaml:"user"` - Pass string `yaml:"pass"` + Hosts []NatsHost `yaml:"hosts"` + User string `yaml:"user"` + Pass string `yaml:"pass"` + TLSEnabled bool `yaml:"tls_enabled"` + CACerts string `yaml:"ca_certs"` + CAPool *x509.CertPool `yaml:"-"` + ClientAuthCertificate tls.Certificate `yaml:"-"` + TLSPem `yaml:",inline"` // embed to get cert_chain and private_key for client authentication +} + +type NatsHost struct { + Hostname string + Port uint16 +} + +var defaultNatsConfig = NatsConfig{ + Hosts: []NatsHost{{Hostname: "localhost", Port: 42222}}, + User: "", + Pass: "", } type RoutingApiConfig struct { @@ -94,13 +109,6 @@ type RoutingApiConfig struct { TLSPem `yaml:",inline"` // embed to get cert_chain and private_key for client authentication } -var defaultNatsConfig = NatsConfig{ - Host: "localhost", - Port: 4222, - User: "", - Pass: "", -} - type OAuthConfig struct { TokenEndpoint string `yaml:"token_endpoint"` Port int `yaml:"port"` @@ -181,7 +189,7 @@ type HTTPRewriteResponses struct { type Config struct { Status StatusConfig `yaml:"status,omitempty"` - Nats []NatsConfig `yaml:"nats,omitempty"` + Nats NatsConfig `yaml:"nats,omitempty"` Logging LoggingConfig `yaml:"logging,omitempty"` Port uint16 `yaml:"port,omitempty"` Index uint `yaml:"index,omitempty"` @@ -283,7 +291,7 @@ type Config struct { var defaultConfig = Config{ Status: defaultStatusConfig, - Nats: []NatsConfig{defaultNatsConfig}, + Nats: defaultNatsConfig, Logging: defaultLoggingConfig, Port: 8081, Index: 0, @@ -399,6 +407,21 @@ func (c *Config) Process() error { c.RoutingApi.CAPool = certPool } + if c.Nats.TLSEnabled { + certificate, err := tls.X509KeyPair([]byte(c.Nats.CertChain), []byte(c.Nats.PrivateKey)) + if err != nil { + errMsg := fmt.Sprintf("Error loading NATS key pair: %s", err.Error()) + return fmt.Errorf(errMsg) + } + c.Nats.ClientAuthCertificate = certificate + + certPool := x509.NewCertPool() + if ok := certPool.AppendCertsFromPEM([]byte(c.Nats.CACerts)); !ok { + return fmt.Errorf("Error while adding CACerts to gorouter's routing-api cert pool: \n%s\n", c.Nats.CACerts) + } + c.Nats.CAPool = certPool + } + if c.EnableSSL { switch c.ClientCertificateValidationString { case "none": @@ -650,11 +673,11 @@ func convertCipherStringToInt(cipherStrs []string, cipherMap map[string]uint16) func (c *Config) NatsServers() []string { var natsServers []string - for _, info := range c.Nats { + for _, host := range c.Nats.Hosts { uri := url.URL{ Scheme: "nats", - User: url.UserPassword(info.User, info.Pass), - Host: fmt.Sprintf("%s:%d", info.Host, info.Port), + User: url.UserPassword(c.Nats.User, c.Nats.Pass), + Host: fmt.Sprintf("%s:%d", host.Hostname, host.Port), } natsServers = append(natsServers, uri.String()) } @@ -667,7 +690,6 @@ func (c *Config) RoutingApiEnabled() bool { } func (c *Config) Initialize(configYAML []byte) error { - c.Nats = []NatsConfig{} return yaml.Unmarshal(configYAML, &c) } diff --git a/config/config_test.go b/config/config_test.go index 6d2f0d5d2..4633e9d07 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -116,22 +116,78 @@ endpoint_keep_alive_probe_interval: 500ms Expect(config.EndpointKeepAliveProbeInterval).To(Equal(500 * time.Millisecond)) }) - It("sets nats config", func() { - var b = []byte(` + Context("NATS Config", func() { + It("handles basic nats config", func() { + var b = []byte(` nats: - - host: remotehost + user: user + pass: pass + hosts: + - hostname: remotehost port: 4223 - user: user - pass: pass `) - err := config.Initialize(b) - Expect(err).ToNot(HaveOccurred()) + err := config.Initialize(b) + Expect(err).ToNot(HaveOccurred()) + + Expect(config.Nats.User).To(Equal("user")) + Expect(config.Nats.Pass).To(Equal("pass")) + Expect(config.Nats.Hosts).To(HaveLen(1)) + Expect(config.Nats.Hosts[0].Hostname).To(Equal("remotehost")) + Expect(config.Nats.Hosts[0].Port).To(Equal(uint16(4223))) + }) + + Context("when TLSEnabled is set to true", func() { + var ( + err error + configSnippet *Config + caCert tls.Certificate + clientPair tls.Certificate + ) + + createYMLSnippet := func(snippet *Config) []byte { + cfgBytes, err := yaml.Marshal(snippet) + Expect(err).ToNot(HaveOccurred()) + return cfgBytes + } + + BeforeEach(func() { + caCertChain := test_util.CreateSignedCertWithRootCA(test_util.CertNames{CommonName: "spinach.com"}) + clientKeyPEM, clientCertPEM := test_util.CreateKeyPair("potato.com") + + caCert, err = tls.X509KeyPair(append(caCertChain.CertPEM, caCertChain.CACertPEM...), caCertChain.PrivKeyPEM) + Expect(err).ToNot(HaveOccurred()) + clientPair, err = tls.X509KeyPair(clientCertPEM, clientKeyPEM) + Expect(err).ToNot(HaveOccurred()) + + configSnippet = &Config{ + Nats: NatsConfig{ + TLSEnabled: true, + CACerts: fmt.Sprintf("%s%s", caCertChain.CertPEM, caCertChain.CACertPEM), + TLSPem: TLSPem{ + CertChain: string(clientCertPEM), + PrivateKey: string(clientKeyPEM), + }, + }, + } + }) - Expect(config.Nats).To(HaveLen(1)) - Expect(config.Nats[0].Host).To(Equal("remotehost")) - Expect(config.Nats[0].Port).To(Equal(uint16(4223))) - Expect(config.Nats[0].User).To(Equal("user")) - Expect(config.Nats[0].Pass).To(Equal("pass")) + It("configures TLS", func() { + configBytes := createYMLSnippet(configSnippet) + err = config.Initialize(configBytes) + Expect(err).NotTo(HaveOccurred()) + err = config.Process() + Expect(err).NotTo(HaveOccurred()) + + Expect(config.Nats.CAPool).ToNot(BeNil()) + poolSubjects := config.Nats.CAPool.Subjects() + parsedCert, err := x509.ParseCertificate(caCert.Certificate[0]) + Expect(err).NotTo(HaveOccurred()) + expectedSubject := parsedCert.RawSubject + + Expect(string(poolSubjects[0])).To(Equal(string(expectedSubject))) + Expect(config.Nats.ClientAuthCertificate).To(Equal(clientPair)) + }) + }) }) Context("Suspend Pruning option", func() { @@ -752,14 +808,13 @@ secure_cookies: false Describe("NatsServers", func() { var b = []byte(` nats: - - host: remotehost - port: 4223 - user: user - pass: pass - - host: remotehost2 + user: user + pass: pass + hosts: + - hostname: remotehost port: 4223 - user: user2 - pass: pass2 + - hostname: remotehost2 + port: 4224 `) It("returns a slice of the configured NATS servers", func() { @@ -768,7 +823,7 @@ nats: natsServers := config.NatsServers() Expect(natsServers[0]).To(Equal("nats://user:pass@remotehost:4223")) - Expect(natsServers[1]).To(Equal("nats://user2:pass2@remotehost2:4223")) + Expect(natsServers[1]).To(Equal("nats://user:pass@remotehost2:4224")) }) }) diff --git a/integration/common_integration_test.go b/integration/common_integration_test.go index af605d32b..bc7718423 100644 --- a/integration/common_integration_test.go +++ b/integration/common_integration_test.go @@ -254,7 +254,7 @@ func (s *testState) registerAndWait(rm mbus.RegistryMessage) { func (s *testState) StartGorouter() *Session { Expect(s.cfg).NotTo(BeNil(), "set up test cfg before calling this function") - s.natsRunner = test_util.NewNATSRunner(int(s.cfg.Nats[0].Port)) + s.natsRunner = test_util.NewNATSRunner(int(s.cfg.Nats.Hosts[0].Port)) s.natsRunner.Start() var err error diff --git a/integration/main_test.go b/integration/main_test.go index b85d6aae5..d3be6242a 100644 --- a/integration/main_test.go +++ b/integration/main_test.go @@ -1111,14 +1111,14 @@ func hostnameAndPort(url string) (string, int) { } func newMessageBus(c *config.Config) (*nats.Conn, error) { - natsMembers := make([]string, len(c.Nats)) + natsMembers := make([]string, len(c.Nats.Hosts)) options := nats.DefaultOptions options.PingInterval = 200 * time.Millisecond - for _, info := range c.Nats { + for _, host := range c.Nats.Hosts { uri := url.URL{ Scheme: "nats", - User: url.UserPassword(info.User, info.Pass), - Host: fmt.Sprintf("%s:%d", info.Host, info.Port), + User: url.UserPassword(c.Nats.User, c.Nats.Pass), + Host: fmt.Sprintf("%s:%d", host.Hostname, host.Port), } natsMembers = append(natsMembers, uri.String()) } diff --git a/mbus/client.go b/mbus/client.go index e0451a872..e52349134 100644 --- a/mbus/client.go +++ b/mbus/client.go @@ -8,6 +8,7 @@ import ( "code.cloudfoundry.org/gorouter/config" "code.cloudfoundry.org/gorouter/logger" + "code.cloudfoundry.org/tlsconfig" "github.com/nats-io/nats.go" "github.com/uber-go/zap" ) @@ -54,10 +55,20 @@ func Connect(c *config.Config, reconnected chan<- Signal, l logger.Logger) *nats } func natsOptions(l logger.Logger, c *config.Config, natsHost *atomic.Value, reconnected chan<- Signal) nats.Options { - natsServers := c.NatsServers() - options := nats.DefaultOptions - options.Servers = natsServers + options.Servers = c.NatsServers() + if c.Nats.TLSEnabled { + var err error + options.TLSConfig, err = tlsconfig.Build( + tlsconfig.WithInternalServiceDefaults(), + tlsconfig.WithIdentity(c.Nats.ClientAuthCertificate), + ).Client( + tlsconfig.WithAuthority(c.Nats.CAPool), + ) + if err != nil { + l.Fatal("nats-tls-config-invalid", zap.Object("error", err)) + } + } options.PingInterval = c.NatsClientPingInterval options.MaxReconnect = -1 notDisconnected := make(chan Signal) diff --git a/test_util/helpers.go b/test_util/helpers.go index 802a6eed6..cc6907807 100644 --- a/test_util/helpers.go +++ b/test_util/helpers.go @@ -240,14 +240,15 @@ func generateConfig(statusPort, proxyPort uint16, natsPorts ...uint16) *config.C Pass: "pass", } - c.Nats = []config.NatsConfig{} - for _, natsPort := range natsPorts { - c.Nats = append(c.Nats, config.NatsConfig{ - Host: "localhost", - Port: natsPort, - User: "nats", - Pass: "nats", - }) + natsHosts := make([]config.NatsHost, len(natsPorts)) + for i, natsPort := range natsPorts { + natsHosts[i].Hostname = "localhost" + natsHosts[i].Port = natsPort + } + c.Nats = config.NatsConfig{ + User: "nats", + Pass: "nats", + Hosts: natsHosts, } c.Logging.Level = "debug"