Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for talking to NATS over mTLS #283

Merged
merged 1 commit into from
Jul 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 39 additions & 17 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,25 @@ var defaultStatusConfig = StatusConfig{
}

type NatsConfig struct {
Host string `yaml:"host"`
Port uint16 `yaml:"port"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
Hosts []NatsHost `yaml:"hosts"`
User string `yaml:"user"`
Pass string `yaml:"pass"`
TLSEnabled bool `yaml:"tls_enabled"`
CACerts string `yaml:"ca_certs"`
CAPool *x509.CertPool `yaml:"-"`
ClientAuthCertificate tls.Certificate `yaml:"-"`
TLSPem `yaml:",inline"` // embed to get cert_chain and private_key for client authentication
}

type NatsHost struct {
Hostname string
Port uint16
}

var defaultNatsConfig = NatsConfig{
Hosts: []NatsHost{{Hostname: "localhost", Port: 42222}},
User: "",
Pass: "",
}

type RoutingApiConfig struct {
Expand All @@ -94,13 +109,6 @@ type RoutingApiConfig struct {
TLSPem `yaml:",inline"` // embed to get cert_chain and private_key for client authentication
}

var defaultNatsConfig = NatsConfig{
Host: "localhost",
Port: 4222,
User: "",
Pass: "",
}

type OAuthConfig struct {
TokenEndpoint string `yaml:"token_endpoint"`
Port int `yaml:"port"`
Expand Down Expand Up @@ -181,7 +189,7 @@ type HTTPRewriteResponses struct {

type Config struct {
Status StatusConfig `yaml:"status,omitempty"`
Nats []NatsConfig `yaml:"nats,omitempty"`
Nats NatsConfig `yaml:"nats,omitempty"`
Logging LoggingConfig `yaml:"logging,omitempty"`
Port uint16 `yaml:"port,omitempty"`
Index uint `yaml:"index,omitempty"`
Expand Down Expand Up @@ -283,7 +291,7 @@ type Config struct {

var defaultConfig = Config{
Status: defaultStatusConfig,
Nats: []NatsConfig{defaultNatsConfig},
Nats: defaultNatsConfig,
Logging: defaultLoggingConfig,
Port: 8081,
Index: 0,
Expand Down Expand Up @@ -399,6 +407,21 @@ func (c *Config) Process() error {
c.RoutingApi.CAPool = certPool
}

if c.Nats.TLSEnabled {
certificate, err := tls.X509KeyPair([]byte(c.Nats.CertChain), []byte(c.Nats.PrivateKey))
if err != nil {
errMsg := fmt.Sprintf("Error loading NATS key pair: %s", err.Error())
return fmt.Errorf(errMsg)
}
c.Nats.ClientAuthCertificate = certificate

certPool := x509.NewCertPool()
if ok := certPool.AppendCertsFromPEM([]byte(c.Nats.CACerts)); !ok {
return fmt.Errorf("Error while adding CACerts to gorouter's routing-api cert pool: \n%s\n", c.Nats.CACerts)
}
c.Nats.CAPool = certPool
}

if c.EnableSSL {
switch c.ClientCertificateValidationString {
case "none":
Expand Down Expand Up @@ -650,11 +673,11 @@ func convertCipherStringToInt(cipherStrs []string, cipherMap map[string]uint16)

func (c *Config) NatsServers() []string {
var natsServers []string
for _, info := range c.Nats {
for _, host := range c.Nats.Hosts {
uri := url.URL{
Scheme: "nats",
User: url.UserPassword(info.User, info.Pass),
Host: fmt.Sprintf("%s:%d", info.Host, info.Port),
User: url.UserPassword(c.Nats.User, c.Nats.Pass),
Host: fmt.Sprintf("%s:%d", host.Hostname, host.Port),
}
natsServers = append(natsServers, uri.String())
}
Expand All @@ -667,7 +690,6 @@ func (c *Config) RoutingApiEnabled() bool {
}

func (c *Config) Initialize(configYAML []byte) error {
c.Nats = []NatsConfig{}
return yaml.Unmarshal(configYAML, &c)
}

Expand Down
95 changes: 75 additions & 20 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,22 +116,78 @@ endpoint_keep_alive_probe_interval: 500ms
Expect(config.EndpointKeepAliveProbeInterval).To(Equal(500 * time.Millisecond))
})

It("sets nats config", func() {
var b = []byte(`
Context("NATS Config", func() {
It("handles basic nats config", func() {
var b = []byte(`
nats:
- host: remotehost
user: user
pass: pass
hosts:
- hostname: remotehost
port: 4223
user: user
pass: pass
`)
err := config.Initialize(b)
Expect(err).ToNot(HaveOccurred())
err := config.Initialize(b)
Expect(err).ToNot(HaveOccurred())

Expect(config.Nats.User).To(Equal("user"))
Expect(config.Nats.Pass).To(Equal("pass"))
Expect(config.Nats.Hosts).To(HaveLen(1))
Expect(config.Nats.Hosts[0].Hostname).To(Equal("remotehost"))
Expect(config.Nats.Hosts[0].Port).To(Equal(uint16(4223)))
})

Context("when TLSEnabled is set to true", func() {
var (
err error
configSnippet *Config
caCert tls.Certificate
clientPair tls.Certificate
)

createYMLSnippet := func(snippet *Config) []byte {
cfgBytes, err := yaml.Marshal(snippet)
Expect(err).ToNot(HaveOccurred())
return cfgBytes
}

BeforeEach(func() {
caCertChain := test_util.CreateSignedCertWithRootCA(test_util.CertNames{CommonName: "spinach.com"})
clientKeyPEM, clientCertPEM := test_util.CreateKeyPair("potato.com")

caCert, err = tls.X509KeyPair(append(caCertChain.CertPEM, caCertChain.CACertPEM...), caCertChain.PrivKeyPEM)
Expect(err).ToNot(HaveOccurred())
clientPair, err = tls.X509KeyPair(clientCertPEM, clientKeyPEM)
Expect(err).ToNot(HaveOccurred())

configSnippet = &Config{
Nats: NatsConfig{
TLSEnabled: true,
CACerts: fmt.Sprintf("%s%s", caCertChain.CertPEM, caCertChain.CACertPEM),
TLSPem: TLSPem{
CertChain: string(clientCertPEM),
PrivateKey: string(clientKeyPEM),
},
},
}
})

Expect(config.Nats).To(HaveLen(1))
Expect(config.Nats[0].Host).To(Equal("remotehost"))
Expect(config.Nats[0].Port).To(Equal(uint16(4223)))
Expect(config.Nats[0].User).To(Equal("user"))
Expect(config.Nats[0].Pass).To(Equal("pass"))
It("configures TLS", func() {
configBytes := createYMLSnippet(configSnippet)
err = config.Initialize(configBytes)
Expect(err).NotTo(HaveOccurred())
err = config.Process()
Expect(err).NotTo(HaveOccurred())

Expect(config.Nats.CAPool).ToNot(BeNil())
poolSubjects := config.Nats.CAPool.Subjects()
parsedCert, err := x509.ParseCertificate(caCert.Certificate[0])
Expect(err).NotTo(HaveOccurred())
expectedSubject := parsedCert.RawSubject

Expect(string(poolSubjects[0])).To(Equal(string(expectedSubject)))
Expect(config.Nats.ClientAuthCertificate).To(Equal(clientPair))
})
})
})

Context("Suspend Pruning option", func() {
Expand Down Expand Up @@ -752,14 +808,13 @@ secure_cookies: false
Describe("NatsServers", func() {
var b = []byte(`
nats:
- host: remotehost
port: 4223
user: user
pass: pass
- host: remotehost2
user: user
pass: pass
hosts:
- hostname: remotehost
port: 4223
user: user2
pass: pass2
- hostname: remotehost2
port: 4224
`)

It("returns a slice of the configured NATS servers", func() {
Expand All @@ -768,7 +823,7 @@ nats:

natsServers := config.NatsServers()
Expect(natsServers[0]).To(Equal("nats://user:pass@remotehost:4223"))
Expect(natsServers[1]).To(Equal("nats://user2:pass2@remotehost2:4223"))
Expect(natsServers[1]).To(Equal("nats://user:pass@remotehost2:4224"))
})
})

Expand Down
2 changes: 1 addition & 1 deletion integration/common_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func (s *testState) registerAndWait(rm mbus.RegistryMessage) {
func (s *testState) StartGorouter() *Session {
Expect(s.cfg).NotTo(BeNil(), "set up test cfg before calling this function")

s.natsRunner = test_util.NewNATSRunner(int(s.cfg.Nats[0].Port))
s.natsRunner = test_util.NewNATSRunner(int(s.cfg.Nats.Hosts[0].Port))
s.natsRunner.Start()

var err error
Expand Down
8 changes: 4 additions & 4 deletions integration/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1111,14 +1111,14 @@ func hostnameAndPort(url string) (string, int) {
}

func newMessageBus(c *config.Config) (*nats.Conn, error) {
natsMembers := make([]string, len(c.Nats))
natsMembers := make([]string, len(c.Nats.Hosts))
options := nats.DefaultOptions
options.PingInterval = 200 * time.Millisecond
for _, info := range c.Nats {
for _, host := range c.Nats.Hosts {
uri := url.URL{
Scheme: "nats",
User: url.UserPassword(info.User, info.Pass),
Host: fmt.Sprintf("%s:%d", info.Host, info.Port),
User: url.UserPassword(c.Nats.User, c.Nats.Pass),
Host: fmt.Sprintf("%s:%d", host.Hostname, host.Port),
}
natsMembers = append(natsMembers, uri.String())
}
Expand Down
17 changes: 14 additions & 3 deletions mbus/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"code.cloudfoundry.org/gorouter/config"
"code.cloudfoundry.org/gorouter/logger"
"code.cloudfoundry.org/tlsconfig"
"github.com/nats-io/nats.go"
"github.com/uber-go/zap"
)
Expand Down Expand Up @@ -54,10 +55,20 @@ func Connect(c *config.Config, reconnected chan<- Signal, l logger.Logger) *nats
}

func natsOptions(l logger.Logger, c *config.Config, natsHost *atomic.Value, reconnected chan<- Signal) nats.Options {
natsServers := c.NatsServers()

options := nats.DefaultOptions
options.Servers = natsServers
options.Servers = c.NatsServers()
if c.Nats.TLSEnabled {
var err error
options.TLSConfig, err = tlsconfig.Build(
tlsconfig.WithInternalServiceDefaults(),
tlsconfig.WithIdentity(c.Nats.ClientAuthCertificate),
).Client(
tlsconfig.WithAuthority(c.Nats.CAPool),
)
if err != nil {
l.Fatal("nats-tls-config-invalid", zap.Object("error", err))
}
}
options.PingInterval = c.NatsClientPingInterval
options.MaxReconnect = -1
notDisconnected := make(chan Signal)
Expand Down
17 changes: 9 additions & 8 deletions test_util/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,15 @@ func generateConfig(statusPort, proxyPort uint16, natsPorts ...uint16) *config.C
Pass: "pass",
}

c.Nats = []config.NatsConfig{}
for _, natsPort := range natsPorts {
c.Nats = append(c.Nats, config.NatsConfig{
Host: "localhost",
Port: natsPort,
User: "nats",
Pass: "nats",
})
natsHosts := make([]config.NatsHost, len(natsPorts))
for i, natsPort := range natsPorts {
natsHosts[i].Hostname = "localhost"
natsHosts[i].Port = natsPort
}
c.Nats = config.NatsConfig{
User: "nats",
Pass: "nats",
Hosts: natsHosts,
}

c.Logging.Level = "debug"
Expand Down