From 0abb04967ffd5526b639ce3b2ae6724399e0a225 Mon Sep 17 00:00:00 2001
From: Michael Mokrysz <michael.mokrysz@sap.com>
Date: Wed, 12 May 2021 16:51:31 +0100
Subject: [PATCH] Support talking to NATS over mTLS

Until now Gorouter has been unable to encrypt its connection to NATS.
Mutual TLS has been added for other components that talk to NATS, but
not to Gorouter.

This commit adds support for configuring the NATS connection with
Mutual TLS. This code is based heavily on the implementation in
route-emitter [1].

The config YAML's structure gets some changes: the username and password
is no longer provided separately for every NATS machine. This isn't a
limitation in real-world Cloud Foundry deployments. Port is repeated for
each NATS machine so that the integration tests can still run multiple
NATS on multiple ports.

[1] https://github.com/cloudfoundry/route-emitter/tree/master/diegonats
---
 config/config.go                       | 56 ++++++++++-----
 config/config_test.go                  | 95 ++++++++++++++++++++------
 integration/common_integration_test.go |  2 +-
 integration/main_test.go               |  8 +--
 mbus/client.go                         | 17 ++++-
 test_util/helpers.go                   | 17 ++---
 6 files changed, 142 insertions(+), 53 deletions(-)

diff --git a/config/config.go b/config/config.go
index 5817472c4..9865d06c5 100644
--- a/config/config.go
+++ b/config/config.go
@@ -78,10 +78,25 @@ var defaultStatusConfig = StatusConfig{
 }
 
 type NatsConfig struct {
-	Host string `yaml:"host"`
-	Port uint16 `yaml:"port"`
-	User string `yaml:"user"`
-	Pass string `yaml:"pass"`
+	Hosts                 []NatsHost       `yaml:"hosts"`
+	User                  string           `yaml:"user"`
+	Pass                  string           `yaml:"pass"`
+	TLSEnabled            bool             `yaml:"tls_enabled"`
+	CACerts               string           `yaml:"ca_certs"`
+	CAPool                *x509.CertPool   `yaml:"-"`
+	ClientAuthCertificate tls.Certificate  `yaml:"-"`
+	TLSPem                `yaml:",inline"` // embed to get cert_chain and private_key for client authentication
+}
+
+type NatsHost struct {
+	Hostname string
+	Port     uint16
+}
+
+var defaultNatsConfig = NatsConfig{
+	Hosts: []NatsHost{{Hostname: "localhost", Port: 42222}},
+	User:  "",
+	Pass:  "",
 }
 
 type RoutingApiConfig struct {
@@ -94,13 +109,6 @@ type RoutingApiConfig struct {
 	TLSPem                `yaml:",inline"` // embed to get cert_chain and private_key for client authentication
 }
 
-var defaultNatsConfig = NatsConfig{
-	Host: "localhost",
-	Port: 4222,
-	User: "",
-	Pass: "",
-}
-
 type OAuthConfig struct {
 	TokenEndpoint     string `yaml:"token_endpoint"`
 	Port              int    `yaml:"port"`
@@ -181,7 +189,7 @@ type HTTPRewriteResponses struct {
 
 type Config struct {
 	Status          StatusConfig      `yaml:"status,omitempty"`
-	Nats            []NatsConfig      `yaml:"nats,omitempty"`
+	Nats            NatsConfig        `yaml:"nats,omitempty"`
 	Logging         LoggingConfig     `yaml:"logging,omitempty"`
 	Port            uint16            `yaml:"port,omitempty"`
 	Index           uint              `yaml:"index,omitempty"`
@@ -283,7 +291,7 @@ type Config struct {
 
 var defaultConfig = Config{
 	Status:        defaultStatusConfig,
-	Nats:          []NatsConfig{defaultNatsConfig},
+	Nats:          defaultNatsConfig,
 	Logging:       defaultLoggingConfig,
 	Port:          8081,
 	Index:         0,
@@ -399,6 +407,21 @@ func (c *Config) Process() error {
 		c.RoutingApi.CAPool = certPool
 	}
 
+	if c.Nats.TLSEnabled {
+		certificate, err := tls.X509KeyPair([]byte(c.Nats.CertChain), []byte(c.Nats.PrivateKey))
+		if err != nil {
+			errMsg := fmt.Sprintf("Error loading NATS key pair: %s", err.Error())
+			return fmt.Errorf(errMsg)
+		}
+		c.Nats.ClientAuthCertificate = certificate
+
+		certPool := x509.NewCertPool()
+		if ok := certPool.AppendCertsFromPEM([]byte(c.Nats.CACerts)); !ok {
+			return fmt.Errorf("Error while adding CACerts to gorouter's routing-api cert pool: \n%s\n", c.Nats.CACerts)
+		}
+		c.Nats.CAPool = certPool
+	}
+
 	if c.EnableSSL {
 		switch c.ClientCertificateValidationString {
 		case "none":
@@ -650,11 +673,11 @@ func convertCipherStringToInt(cipherStrs []string, cipherMap map[string]uint16)
 
 func (c *Config) NatsServers() []string {
 	var natsServers []string
-	for _, info := range c.Nats {
+	for _, host := range c.Nats.Hosts {
 		uri := url.URL{
 			Scheme: "nats",
-			User:   url.UserPassword(info.User, info.Pass),
-			Host:   fmt.Sprintf("%s:%d", info.Host, info.Port),
+			User:   url.UserPassword(c.Nats.User, c.Nats.Pass),
+			Host:   fmt.Sprintf("%s:%d", host.Hostname, host.Port),
 		}
 		natsServers = append(natsServers, uri.String())
 	}
@@ -667,7 +690,6 @@ func (c *Config) RoutingApiEnabled() bool {
 }
 
 func (c *Config) Initialize(configYAML []byte) error {
-	c.Nats = []NatsConfig{}
 	return yaml.Unmarshal(configYAML, &c)
 }
 
diff --git a/config/config_test.go b/config/config_test.go
index 6d2f0d5d2..4633e9d07 100644
--- a/config/config_test.go
+++ b/config/config_test.go
@@ -116,22 +116,78 @@ endpoint_keep_alive_probe_interval: 500ms
 			Expect(config.EndpointKeepAliveProbeInterval).To(Equal(500 * time.Millisecond))
 		})
 
-		It("sets nats config", func() {
-			var b = []byte(`
+		Context("NATS Config", func() {
+			It("handles basic nats config", func() {
+				var b = []byte(`
 nats:
-  - host: remotehost
+  user: user
+  pass: pass
+  hosts:
+  - hostname: remotehost
     port: 4223
-    user: user
-    pass: pass
 `)
-			err := config.Initialize(b)
-			Expect(err).ToNot(HaveOccurred())
+				err := config.Initialize(b)
+				Expect(err).ToNot(HaveOccurred())
+
+				Expect(config.Nats.User).To(Equal("user"))
+				Expect(config.Nats.Pass).To(Equal("pass"))
+				Expect(config.Nats.Hosts).To(HaveLen(1))
+				Expect(config.Nats.Hosts[0].Hostname).To(Equal("remotehost"))
+				Expect(config.Nats.Hosts[0].Port).To(Equal(uint16(4223)))
+			})
+
+			Context("when TLSEnabled is set to true", func() {
+				var (
+					err           error
+					configSnippet *Config
+					caCert        tls.Certificate
+					clientPair    tls.Certificate
+				)
+
+				createYMLSnippet := func(snippet *Config) []byte {
+					cfgBytes, err := yaml.Marshal(snippet)
+					Expect(err).ToNot(HaveOccurred())
+					return cfgBytes
+				}
+
+				BeforeEach(func() {
+					caCertChain := test_util.CreateSignedCertWithRootCA(test_util.CertNames{CommonName: "spinach.com"})
+					clientKeyPEM, clientCertPEM := test_util.CreateKeyPair("potato.com")
+
+					caCert, err = tls.X509KeyPair(append(caCertChain.CertPEM, caCertChain.CACertPEM...), caCertChain.PrivKeyPEM)
+					Expect(err).ToNot(HaveOccurred())
+					clientPair, err = tls.X509KeyPair(clientCertPEM, clientKeyPEM)
+					Expect(err).ToNot(HaveOccurred())
+
+					configSnippet = &Config{
+						Nats: NatsConfig{
+							TLSEnabled: true,
+							CACerts:    fmt.Sprintf("%s%s", caCertChain.CertPEM, caCertChain.CACertPEM),
+							TLSPem: TLSPem{
+								CertChain:  string(clientCertPEM),
+								PrivateKey: string(clientKeyPEM),
+							},
+						},
+					}
+				})
 
-			Expect(config.Nats).To(HaveLen(1))
-			Expect(config.Nats[0].Host).To(Equal("remotehost"))
-			Expect(config.Nats[0].Port).To(Equal(uint16(4223)))
-			Expect(config.Nats[0].User).To(Equal("user"))
-			Expect(config.Nats[0].Pass).To(Equal("pass"))
+				It("configures TLS", func() {
+					configBytes := createYMLSnippet(configSnippet)
+					err = config.Initialize(configBytes)
+					Expect(err).NotTo(HaveOccurred())
+					err = config.Process()
+					Expect(err).NotTo(HaveOccurred())
+
+					Expect(config.Nats.CAPool).ToNot(BeNil())
+					poolSubjects := config.Nats.CAPool.Subjects()
+					parsedCert, err := x509.ParseCertificate(caCert.Certificate[0])
+					Expect(err).NotTo(HaveOccurred())
+					expectedSubject := parsedCert.RawSubject
+
+					Expect(string(poolSubjects[0])).To(Equal(string(expectedSubject)))
+					Expect(config.Nats.ClientAuthCertificate).To(Equal(clientPair))
+				})
+			})
 		})
 
 		Context("Suspend Pruning option", func() {
@@ -752,14 +808,13 @@ secure_cookies: false
 		Describe("NatsServers", func() {
 			var b = []byte(`
 nats:
-  - host: remotehost
-    port: 4223
-    user: user
-    pass: pass
-  - host: remotehost2
+  user: user
+  pass: pass
+  hosts:
+  - hostname: remotehost
     port: 4223
-    user: user2
-    pass: pass2
+  - hostname: remotehost2
+    port: 4224
 `)
 
 			It("returns a slice of the configured NATS servers", func() {
@@ -768,7 +823,7 @@ nats:
 
 				natsServers := config.NatsServers()
 				Expect(natsServers[0]).To(Equal("nats://user:pass@remotehost:4223"))
-				Expect(natsServers[1]).To(Equal("nats://user2:pass2@remotehost2:4223"))
+				Expect(natsServers[1]).To(Equal("nats://user:pass@remotehost2:4224"))
 			})
 		})
 
diff --git a/integration/common_integration_test.go b/integration/common_integration_test.go
index af605d32b..bc7718423 100644
--- a/integration/common_integration_test.go
+++ b/integration/common_integration_test.go
@@ -254,7 +254,7 @@ func (s *testState) registerAndWait(rm mbus.RegistryMessage) {
 func (s *testState) StartGorouter() *Session {
 	Expect(s.cfg).NotTo(BeNil(), "set up test cfg before calling this function")
 
-	s.natsRunner = test_util.NewNATSRunner(int(s.cfg.Nats[0].Port))
+	s.natsRunner = test_util.NewNATSRunner(int(s.cfg.Nats.Hosts[0].Port))
 	s.natsRunner.Start()
 
 	var err error
diff --git a/integration/main_test.go b/integration/main_test.go
index b85d6aae5..d3be6242a 100644
--- a/integration/main_test.go
+++ b/integration/main_test.go
@@ -1111,14 +1111,14 @@ func hostnameAndPort(url string) (string, int) {
 }
 
 func newMessageBus(c *config.Config) (*nats.Conn, error) {
-	natsMembers := make([]string, len(c.Nats))
+	natsMembers := make([]string, len(c.Nats.Hosts))
 	options := nats.DefaultOptions
 	options.PingInterval = 200 * time.Millisecond
-	for _, info := range c.Nats {
+	for _, host := range c.Nats.Hosts {
 		uri := url.URL{
 			Scheme: "nats",
-			User:   url.UserPassword(info.User, info.Pass),
-			Host:   fmt.Sprintf("%s:%d", info.Host, info.Port),
+			User:   url.UserPassword(c.Nats.User, c.Nats.Pass),
+			Host:   fmt.Sprintf("%s:%d", host.Hostname, host.Port),
 		}
 		natsMembers = append(natsMembers, uri.String())
 	}
diff --git a/mbus/client.go b/mbus/client.go
index e0451a872..e52349134 100644
--- a/mbus/client.go
+++ b/mbus/client.go
@@ -8,6 +8,7 @@ import (
 
 	"code.cloudfoundry.org/gorouter/config"
 	"code.cloudfoundry.org/gorouter/logger"
+	"code.cloudfoundry.org/tlsconfig"
 	"github.com/nats-io/nats.go"
 	"github.com/uber-go/zap"
 )
@@ -54,10 +55,20 @@ func Connect(c *config.Config, reconnected chan<- Signal, l logger.Logger) *nats
 }
 
 func natsOptions(l logger.Logger, c *config.Config, natsHost *atomic.Value, reconnected chan<- Signal) nats.Options {
-	natsServers := c.NatsServers()
-
 	options := nats.DefaultOptions
-	options.Servers = natsServers
+	options.Servers = c.NatsServers()
+	if c.Nats.TLSEnabled {
+		var err error
+		options.TLSConfig, err = tlsconfig.Build(
+			tlsconfig.WithInternalServiceDefaults(),
+			tlsconfig.WithIdentity(c.Nats.ClientAuthCertificate),
+		).Client(
+			tlsconfig.WithAuthority(c.Nats.CAPool),
+		)
+		if err != nil {
+			l.Fatal("nats-tls-config-invalid", zap.Object("error", err))
+		}
+	}
 	options.PingInterval = c.NatsClientPingInterval
 	options.MaxReconnect = -1
 	notDisconnected := make(chan Signal)
diff --git a/test_util/helpers.go b/test_util/helpers.go
index 802a6eed6..cc6907807 100644
--- a/test_util/helpers.go
+++ b/test_util/helpers.go
@@ -240,14 +240,15 @@ func generateConfig(statusPort, proxyPort uint16, natsPorts ...uint16) *config.C
 		Pass: "pass",
 	}
 
-	c.Nats = []config.NatsConfig{}
-	for _, natsPort := range natsPorts {
-		c.Nats = append(c.Nats, config.NatsConfig{
-			Host: "localhost",
-			Port: natsPort,
-			User: "nats",
-			Pass: "nats",
-		})
+	natsHosts := make([]config.NatsHost, len(natsPorts))
+	for i, natsPort := range natsPorts {
+		natsHosts[i].Hostname = "localhost"
+		natsHosts[i].Port = natsPort
+	}
+	c.Nats = config.NatsConfig{
+		User:  "nats",
+		Pass:  "nats",
+		Hosts: natsHosts,
 	}
 
 	c.Logging.Level = "debug"