Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for talking to NATS over mTLS #283

Merged
merged 1 commit into from
Jul 13, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 39 additions & 17 deletions config/config.go
Original file line number Diff line number Diff line change
@@ -78,10 +78,25 @@ var defaultStatusConfig = StatusConfig{
}

type NatsConfig struct {
Host string `yaml:"host"`
Port uint16 `yaml:"port"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Hosts []NatsHost `yaml:"hosts"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
TLSEnabled bool `yaml:"tls_enabled"`
CACerts string `yaml:"ca_certs"`
CAPool *x509.CertPool `yaml:"-"`
ClientAuthCertificate tls.Certificate `yaml:"-"`
TLSPem `yaml:",inline"` // embed to get cert_chain and private_key for client authentication
}

type NatsHost struct {
Hostname string
Port uint16
}

var defaultNatsConfig = NatsConfig{
Hosts: []NatsHost{{Hostname: "localhost", Port: 42222}},
User: "",
Pass: "",
}

type RoutingApiConfig struct {
@@ -94,13 +109,6 @@ type RoutingApiConfig struct {
TLSPem `yaml:",inline"` // embed to get cert_chain and private_key for client authentication
}

var defaultNatsConfig = NatsConfig{
Host: "localhost",
Port: 4222,
User: "",
Pass: "",
}

type OAuthConfig struct {
TokenEndpoint string `yaml:"token_endpoint"`
Port int `yaml:"port"`
@@ -181,7 +189,7 @@ type HTTPRewriteResponses struct {

type Config struct {
Status StatusConfig `yaml:"status,omitempty"`
Nats []NatsConfig `yaml:"nats,omitempty"`
Nats NatsConfig `yaml:"nats,omitempty"`
Logging LoggingConfig `yaml:"logging,omitempty"`
Port uint16 `yaml:"port,omitempty"`
Index uint `yaml:"index,omitempty"`
@@ -283,7 +291,7 @@ type Config struct {

var defaultConfig = Config{
Status: defaultStatusConfig,
Nats: []NatsConfig{defaultNatsConfig},
Nats: defaultNatsConfig,
Logging: defaultLoggingConfig,
Port: 8081,
Index: 0,
@@ -399,6 +407,21 @@ func (c *Config) Process() error {
c.RoutingApi.CAPool = certPool
}

if c.Nats.TLSEnabled {
certificate, err := tls.X509KeyPair([]byte(c.Nats.CertChain), []byte(c.Nats.PrivateKey))
if err != nil {
errMsg := fmt.Sprintf("Error loading NATS key pair: %s", err.Error())
return fmt.Errorf(errMsg)
}
c.Nats.ClientAuthCertificate = certificate

certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(c.Nats.CACerts)); !ok {
return fmt.Errorf("Error while adding CACerts to gorouter's routing-api cert pool: \n%s\n", c.Nats.CACerts)
}
c.Nats.CAPool = certPool
}

if c.EnableSSL {
switch c.ClientCertificateValidationString {
case "none":
@@ -650,11 +673,11 @@ func convertCipherStringToInt(cipherStrs []string, cipherMap map[string]uint16)

func (c *Config) NatsServers() []string {
var natsServers []string
for _, info := range c.Nats {
for _, host := range c.Nats.Hosts {
uri := url.URL{
Scheme: "nats",
User: url.UserPassword(info.User, info.Pass),
Host: fmt.Sprintf("%s:%d", info.Host, info.Port),
User: url.UserPassword(c.Nats.User, c.Nats.Pass),
Host: fmt.Sprintf("%s:%d", host.Hostname, host.Port),
}
natsServers = append(natsServers, uri.String())
}
@@ -667,7 +690,6 @@ func (c *Config) RoutingApiEnabled() bool {
}

func (c *Config) Initialize(configYAML []byte) error {
c.Nats = []NatsConfig{}
return yaml.Unmarshal(configYAML, &c)
}

95 changes: 75 additions & 20 deletions config/config_test.go
Original file line number Diff line number Diff line change
@@ -116,22 +116,78 @@ endpoint_keep_alive_probe_interval: 500ms
Expect(config.EndpointKeepAliveProbeInterval).To(Equal(500 * time.Millisecond))
})

It("sets nats config", func() {
var b = []byte(`
Context("NATS Config", func() {
It("handles basic nats config", func() {
var b = []byte(`
nats:
- host: remotehost
user: user
pass: pass
hosts:
- hostname: remotehost
port: 4223
user: user
pass: pass
`)
err := config.Initialize(b)
Expect(err).ToNot(HaveOccurred())
err := config.Initialize(b)
Expect(err).ToNot(HaveOccurred())

Expect(config.Nats.User).To(Equal("user"))
Expect(config.Nats.Pass).To(Equal("pass"))
Expect(config.Nats.Hosts).To(HaveLen(1))
Expect(config.Nats.Hosts[0].Hostname).To(Equal("remotehost"))
Expect(config.Nats.Hosts[0].Port).To(Equal(uint16(4223)))
})

Context("when TLSEnabled is set to true", func() {
var (
err error
configSnippet *Config
caCert tls.Certificate
clientPair tls.Certificate
)

createYMLSnippet := func(snippet *Config) []byte {
cfgBytes, err := yaml.Marshal(snippet)
Expect(err).ToNot(HaveOccurred())
return cfgBytes
}

BeforeEach(func() {
caCertChain := test_util.CreateSignedCertWithRootCA(test_util.CertNames{CommonName: "spinach.com"})
clientKeyPEM, clientCertPEM := test_util.CreateKeyPair("potato.com")

caCert, err = tls.X509KeyPair(append(caCertChain.CertPEM, caCertChain.CACertPEM...), caCertChain.PrivKeyPEM)
Expect(err).ToNot(HaveOccurred())
clientPair, err = tls.X509KeyPair(clientCertPEM, clientKeyPEM)
Expect(err).ToNot(HaveOccurred())

configSnippet = &Config{
Nats: NatsConfig{
TLSEnabled: true,
CACerts: fmt.Sprintf("%s%s", caCertChain.CertPEM, caCertChain.CACertPEM),
TLSPem: TLSPem{
CertChain: string(clientCertPEM),
PrivateKey: string(clientKeyPEM),
},
},
}
})

Expect(config.Nats).To(HaveLen(1))
Expect(config.Nats[0].Host).To(Equal("remotehost"))
Expect(config.Nats[0].Port).To(Equal(uint16(4223)))
Expect(config.Nats[0].User).To(Equal("user"))
Expect(config.Nats[0].Pass).To(Equal("pass"))
It("configures TLS", func() {
configBytes := createYMLSnippet(configSnippet)
err = config.Initialize(configBytes)
Expect(err).NotTo(HaveOccurred())
err = config.Process()
Expect(err).NotTo(HaveOccurred())

Expect(config.Nats.CAPool).ToNot(BeNil())
poolSubjects := config.Nats.CAPool.Subjects()
parsedCert, err := x509.ParseCertificate(caCert.Certificate[0])
Expect(err).NotTo(HaveOccurred())
expectedSubject := parsedCert.RawSubject

Expect(string(poolSubjects[0])).To(Equal(string(expectedSubject)))
Expect(config.Nats.ClientAuthCertificate).To(Equal(clientPair))
})
})
})

Context("Suspend Pruning option", func() {
@@ -752,14 +808,13 @@ secure_cookies: false
Describe("NatsServers", func() {
var b = []byte(`
nats:
- host: remotehost
port: 4223
user: user
pass: pass
- host: remotehost2
user: user
pass: pass
hosts:
- hostname: remotehost
port: 4223
user: user2
pass: pass2
- hostname: remotehost2
port: 4224
`)

It("returns a slice of the configured NATS servers", func() {
@@ -768,7 +823,7 @@ nats:

natsServers := config.NatsServers()
Expect(natsServers[0]).To(Equal("nats://user:pass@remotehost:4223"))
Expect(natsServers[1]).To(Equal("nats://user2:pass2@remotehost2:4223"))
Expect(natsServers[1]).To(Equal("nats://user:pass@remotehost2:4224"))
})
})

2 changes: 1 addition & 1 deletion integration/common_integration_test.go
Original file line number Diff line number Diff line change
@@ -254,7 +254,7 @@ func (s *testState) registerAndWait(rm mbus.RegistryMessage) {
func (s *testState) StartGorouter() *Session {
Expect(s.cfg).NotTo(BeNil(), "set up test cfg before calling this function")

s.natsRunner = test_util.NewNATSRunner(int(s.cfg.Nats[0].Port))
s.natsRunner = test_util.NewNATSRunner(int(s.cfg.Nats.Hosts[0].Port))
s.natsRunner.Start()

var err error
8 changes: 4 additions & 4 deletions integration/main_test.go
Original file line number Diff line number Diff line change
@@ -1111,14 +1111,14 @@ func hostnameAndPort(url string) (string, int) {
}

func newMessageBus(c *config.Config) (*nats.Conn, error) {
natsMembers := make([]string, len(c.Nats))
natsMembers := make([]string, len(c.Nats.Hosts))
options := nats.DefaultOptions
options.PingInterval = 200 * time.Millisecond
for _, info := range c.Nats {
for _, host := range c.Nats.Hosts {
uri := url.URL{
Scheme: "nats",
User: url.UserPassword(info.User, info.Pass),
Host: fmt.Sprintf("%s:%d", info.Host, info.Port),
User: url.UserPassword(c.Nats.User, c.Nats.Pass),
Host: fmt.Sprintf("%s:%d", host.Hostname, host.Port),
}
natsMembers = append(natsMembers, uri.String())
}
17 changes: 14 additions & 3 deletions mbus/client.go
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@ import (

"code.cloudfoundry.org/gorouter/config"
"code.cloudfoundry.org/gorouter/logger"
"code.cloudfoundry.org/tlsconfig"
"github.com/nats-io/nats.go"
"github.com/uber-go/zap"
)
@@ -54,10 +55,20 @@ func Connect(c *config.Config, reconnected chan<- Signal, l logger.Logger) *nats
}

func natsOptions(l logger.Logger, c *config.Config, natsHost *atomic.Value, reconnected chan<- Signal) nats.Options {
natsServers := c.NatsServers()

options := nats.DefaultOptions
options.Servers = natsServers
options.Servers = c.NatsServers()
if c.Nats.TLSEnabled {
var err error
options.TLSConfig, err = tlsconfig.Build(
tlsconfig.WithInternalServiceDefaults(),
tlsconfig.WithIdentity(c.Nats.ClientAuthCertificate),
).Client(
tlsconfig.WithAuthority(c.Nats.CAPool),
)
if err != nil {
l.Fatal("nats-tls-config-invalid", zap.Object("error", err))
}
}
options.PingInterval = c.NatsClientPingInterval
options.MaxReconnect = -1
notDisconnected := make(chan Signal)
17 changes: 9 additions & 8 deletions test_util/helpers.go
Original file line number Diff line number Diff line change
@@ -240,14 +240,15 @@ func generateConfig(statusPort, proxyPort uint16, natsPorts ...uint16) *config.C
Pass: "pass",
}

c.Nats = []config.NatsConfig{}
for _, natsPort := range natsPorts {
c.Nats = append(c.Nats, config.NatsConfig{
Host: "localhost",
Port: natsPort,
User: "nats",
Pass: "nats",
})
natsHosts := make([]config.NatsHost, len(natsPorts))
for i, natsPort := range natsPorts {
natsHosts[i].Hostname = "localhost"
natsHosts[i].Port = natsPort
}
c.Nats = config.NatsConfig{
User: "nats",
Pass: "nats",
Hosts: natsHosts,
}

c.Logging.Level = "debug"