diff --git a/internal/cmd/cli/login/login_cmd.go b/internal/cmd/cli/login/login_cmd.go index b2a97266d..6d4f395da 100644 --- a/internal/cmd/cli/login/login_cmd.go +++ b/internal/cmd/cli/login/login_cmd.go @@ -368,21 +368,31 @@ func (c *runnerContext) run(cmd *cobra.Command, args []string) error { cfg.SetAddress(c.address) cfg.SetPrivate(c.args.private) - // For CA files that are absolute we need to store only the path, but for those that are relative we need to - // save the content because otherwise we will not be able to use them when the command is executed from a - // different directory. + // Store the CA entries. For regular files we save both the absolute path and the content so that the + // certificate can be reloaded from disk after rotation, with the stored content as a fallback. For directories + // we only store the absolute path so that their contents are scanned on every invocation, allowing files to be + // added or removed inside the directory. for _, caFile := range c.args.caFiles { - if filepath.IsAbs(caFile) { + caPath := filepath.Clean(caFile) + caPath, err = filepath.Abs(caPath) + if err != nil { + return fmt.Errorf("failed to resolve absolute path for '%s': %w", caFile, err) + } + caInfo, err := os.Stat(caPath) + if err != nil { + return fmt.Errorf("CA path '%s' is not accessible: %w", caPath, err) + } + if caInfo.IsDir() { cfg.AddCaFile(config.CaFile{ - Name: caFile, + Name: caPath, }) } else { - caContent, err := os.ReadFile(filepath.Clean(caFile)) + caContent, err := os.ReadFile(caPath) if err != nil { - return fmt.Errorf("failed to read CA file '%s': %w", caFile, err) + return fmt.Errorf("failed to read CA file '%s': %w", caPath, err) } cfg.AddCaFile(config.CaFile{ - Name: caFile, + Name: caPath, Content: string(caContent), }) } diff --git a/internal/config/config_settings.go b/internal/config/config_settings.go index fe5a4650b..f616f9751 100644 --- a/internal/config/config_settings.go +++ b/internal/config/config_settings.go @@ -559,50 +559,68 @@ func (c *Settings) CaPool(ctx context.Context) (result *x509.CertPool, err error } func (c *Settings) createCaPool(ctx context.Context) error { - // Create a temporary directory for the CA files that we have content for. Those will usually be the CA files - // that were specified with relative paths when the configuration was saved. The rest of the CA files, the ones with - // absolute paths, will be loaded from the filesystem. + // Classify configured CA entries into certificate content and directory paths. + // + // Entries with a relative path are ignored because the tool may be invoked from a different working directory. + // This is acceptable because the CLI always calculates absolute paths at configuration time. + // + // Entries with an absolute path and stored content represent regular files whose content was captured by the + // CLI at configuration time. We try to re-read the file from disk so that rotated certificates are picked up + // automatically; if the file is no longer accessible we fall back to the stored content. + // + // Entries with an absolute path but no stored content represent directories. These are passed through to the + // certifiacte pool builder so that their contents are scanned on every invocation, allowing the user to add or + // remove certificate files inside the directory. var ( - contentFiles []CaFile - otherFiles []string + caCerts []any + caFiles []string ) for _, caFile := range c.general.CaFiles { - if caFile.Content != "" { - contentFiles = append(contentFiles, caFile) - } else { - otherFiles = append(otherFiles, caFile.Name) - } - } - contentDir, err := os.MkdirTemp("", "ca-*") - if err != nil { - return fmt.Errorf("failed to create temporary directory for CA files: %w", err) - } - defer func() { - err := os.RemoveAll(contentDir) - if err != nil { - c.logger.ErrorContext( + caPath := filepath.Clean(caFile.Name) + caContent := caFile.Content + if !filepath.IsAbs(caPath) { + c.logger.WarnContext( ctx, - "Failed to remove temporary directory for CA files", - slog.Any("error", err), + "Ignoring CA entry with relative path", + slog.String("file", caPath), ) + continue } - }() - for i, contentFile := range contentFiles { - contentName := fmt.Sprintf("%d-%s", i, filepath.Base(contentFile.Name)) - contentPath := filepath.Join(contentDir, contentName) - err = os.WriteFile(contentPath, []byte(contentFile.Content), 0600) - if err != nil { - return fmt.Errorf("failed to write CA file to temporary directory: %w", err) + if caContent != "" { + caBytes, err := os.ReadFile(caPath) + if err != nil { + c.logger.WarnContext( + ctx, + "CA file is not readable, using stored content", + slog.String("file", caPath), + slog.String("error", err.Error()), + ) + caCerts = append(caCerts, caContent) + } else { + caCerts = append(caCerts, caBytes) + } + } else { + caInfo, err := os.Stat(caPath) + if err != nil || !caInfo.IsDir() { + c.logger.WarnContext( + ctx, + "CA entry without stored content is not an accessible directory", + slog.String("file", caPath), + ) + continue + } + caFiles = append(caFiles, caPath) } } // Create the CA pool: + var err error c.caPool, err = network.NewCertPool(). SetLogger(c.logger). AddSystemFiles(true). AddKubernetesFiles(true). - AddFile(contentDir). - AddFiles(otherFiles...). + AddCertificates(caCerts...). + AddFiles(caFiles...). Build() return err } diff --git a/internal/config/config_settings_test.go b/internal/config/config_settings_test.go index f57499865..c0fb24b83 100644 --- a/internal/config/config_settings_test.go +++ b/internal/config/config_settings_test.go @@ -15,7 +15,12 @@ package config import ( "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" "fmt" + "math/big" "os" "path/filepath" "time" @@ -29,17 +34,11 @@ import ( ) var _ = Describe("Settings", func() { - var ( - ctx context.Context - tmp string - ) + var tmp string BeforeEach(func() { var err error - // Create a context: - ctx = context.Background() - // Create a temporary directory: tmp, err = os.MkdirTemp("", "*.test") Expect(err).ToNot(HaveOccurred()) @@ -69,7 +68,7 @@ var _ = Describe("Settings", func() { keyring.MockInit() }) - It("Loads general settings from the config file", func() { + It("Loads general settings from the config file", func(ctx context.Context) { // Create the config file: file := filepath.Join(tmp, "config.json") content := []byte(`{ @@ -103,7 +102,7 @@ var _ = Describe("Settings", func() { Expect(settings.Private()).To(BeTrue()) }) - It("Saves general settings in the config file", func() { + It("Saves general settings in the config file", func(ctx context.Context) { // Create the settings and save them: settings, err := NewSettings(). SetLogger(logger). @@ -135,7 +134,7 @@ var _ = Describe("Settings", func() { }`)) }) - It("Returns empty settings when no file exists", func() { + It("Returns empty settings when no file exists", func(ctx context.Context) { settings, err := NewSettings(). SetLogger(logger). SetDir(tmp). @@ -152,7 +151,7 @@ var _ = Describe("Settings", func() { Expect(settings.Private()).To(BeFalse()) }) - It("Returns nil token when no access token is present", func() { + It("Returns nil token when no access token is present", func(ctx context.Context) { settings, err := NewSettings(). SetLogger(logger). SetDir(tmp). @@ -164,7 +163,7 @@ var _ = Describe("Settings", func() { Expect(token).To(BeNil()) }) - It("Skips save when tokens have not changed", func() { + It("Skips save when tokens have not changed", func(ctx context.Context) { cfg, err := NewSettings(). SetLogger(logger). SetDir(tmp). @@ -189,7 +188,7 @@ var _ = Describe("Settings", func() { keyring.MockInit() }) - It("Saves secrets in the keyring, not in the config file", func() { + It("Saves secrets in the keyring, not in the config file", func(ctx context.Context) { // Create the settings and save them: settings, err := NewSettings(). SetLogger(logger). @@ -226,7 +225,7 @@ var _ = Describe("Settings", func() { Expect(content).To(MatchJSON(`{}`)) }) - It("Persists tokens when saving through the token store", func() { + It("Persists tokens when saving through the token store", func(ctx context.Context) { // Create the settings and save them: settings, err := NewSettings(). SetLogger(logger). @@ -261,7 +260,7 @@ var _ = Describe("Settings", func() { keyring.MockInitWithError(fmt.Errorf("keyring backend not available")) }) - It("Saves secrets in the secrets file, not in the keyring", func() { + It("Saves secrets in the secrets file, not in the keyring", func(ctx context.Context) { // Create the settings and save them: settings, err := NewSettings(). SetLogger(logger). @@ -293,7 +292,7 @@ var _ = Describe("Settings", func() { }`)) }) - It("Persists tokens when saving through the token store", func() { + It("Persists tokens when saving through the token store", func(ctx context.Context) { // Create the settings and save them: settings, err := NewSettings(). SetLogger(logger). @@ -324,6 +323,128 @@ var _ = Describe("Settings", func() { }) }) + Describe("CA pool creation", func() { + var pemBytes []byte + + BeforeEach(func() { + keyring.MockInit() + + // Generate a test CA certificate: + key, err := rsa.GenerateKey(rand.Reader, 2048) + Expect(err).ToNot(HaveOccurred()) + now := time.Now() + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: now.Add(-time.Hour), + NotAfter: now.Add(time.Hour), + IsCA: true, + BasicConstraintsValid: true, + } + derBytes, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + Expect(err).ToNot(HaveOccurred()) + pemBytes = pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: derBytes, + }) + }) + + It("Ignores entries with relative paths", func(ctx context.Context) { + settings, err := NewSettings(). + SetLogger(logger). + SetDir(tmp). + Build() + Expect(err).ToNot(HaveOccurred()) + settings.AddCaFile(CaFile{ + Name: "relative/path/ca.pem", + Content: string(pemBytes), + }) + pool, err := settings.CaPool(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Reads file from disk when absolute path and content are both set", func(ctx context.Context) { + pemFile := filepath.Join(tmp, "ca.pem") + err := os.WriteFile(pemFile, pemBytes, 0600) + Expect(err).ToNot(HaveOccurred()) + settings, err := NewSettings(). + SetLogger(logger). + SetDir(tmp). + Build() + Expect(err).ToNot(HaveOccurred()) + settings.AddCaFile(CaFile{ + Name: pemFile, + Content: string(pemBytes), + }) + pool, err := settings.CaPool(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Falls back to stored content when file is not readable", func(ctx context.Context) { + settings, err := NewSettings(). + SetLogger(logger). + SetDir(tmp). + Build() + Expect(err).ToNot(HaveOccurred()) + settings.AddCaFile(CaFile{ + Name: "/nonexistent/ca.pem", + Content: string(pemBytes), + }) + pool, err := settings.CaPool(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Adds CA files when absolute path without content is a directory", func(ctx context.Context) { + caDir := filepath.Join(tmp, "ca") + err := os.MkdirAll(caDir, 0755) + Expect(err).ToNot(HaveOccurred()) + settings, err := NewSettings(). + SetLogger(logger). + SetDir(tmp). + Build() + Expect(err).ToNot(HaveOccurred()) + settings.AddCaFile(CaFile{ + Name: caDir, + }) + pool, err := settings.CaPool(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Ignores CA files when absolute path without content does not exist", func(ctx context.Context) { + settings, err := NewSettings(). + SetLogger(logger). + SetDir(tmp). + Build() + Expect(err).ToNot(HaveOccurred()) + settings.AddCaFile(CaFile{ + Name: "/nonexistent/directory", + }) + pool, err := settings.CaPool(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Ignores absolute path without content that is a regular file", func(ctx context.Context) { + pemFile := filepath.Join(tmp, "not-a-dir.pem") + err := os.WriteFile(pemFile, pemBytes, 0600) + Expect(err).ToNot(HaveOccurred()) + settings, err := NewSettings(). + SetLogger(logger). + SetDir(tmp). + Build() + Expect(err).ToNot(HaveOccurred()) + settings.AddCaFile(CaFile{ + Name: pemFile, + }) + pool, err := settings.CaPool(ctx) + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + }) + Describe("Armed check", func() { It("Returns false when settings are nil", func() { var settings *Settings diff --git a/internal/network/cert_pool.go b/internal/network/cert_pool.go index 6541f1b4d..1cc88b865 100644 --- a/internal/network/cert_pool.go +++ b/internal/network/cert_pool.go @@ -33,6 +33,7 @@ type CertPoolBuilder struct { root string files []string exts []string + certs []any } // NewCertPool creates a builder that can then used to configure and create a certificate pool. @@ -110,6 +111,20 @@ func (b *CertPoolBuilder) AddExtensions(values ...string) *CertPoolBuilder { return b } +// AddCertificate adds a certificate to the pool. The value can be a string or a slice of bytes containing a PEM +// encoded certificate, or a *x509.Certificate object. +func (b *CertPoolBuilder) AddCertificate(value any) *CertPoolBuilder { + b.certs = append(b.certs, value) + return b +} + +// AddCertificates adds multiple certificates to the pool. Each value can be a string or a slice of bytes containing +// a PEM encoded certificate, or a *x509.Certificate object. +func (b *CertPoolBuilder) AddCertificates(values ...any) *CertPoolBuilder { + b.certs = append(b.certs, values...) + return b +} + // Build uses the data stored in the builder to create a new certificate pool. func (b *CertPoolBuilder) Build() (result *x509.CertPool, err error) { // Check parameters: @@ -140,12 +155,38 @@ func (b *CertPoolBuilder) Build() (result *x509.CertPool, err error) { } } - // Load configured CA files: + // Load configured files: err = b.loadConfiguredFiles(pool) if err != nil { return } + // Load configured certificates: + for _, cert := range b.certs { + switch cert := cert.(type) { + case string: + ok := pool.AppendCertsFromPEM([]byte(cert)) + if !ok { + err = fmt.Errorf("failed to add certificate to pool: %w", err) + return + } + case []byte: + ok := pool.AppendCertsFromPEM(cert) + if !ok { + err = fmt.Errorf("failed to add certificate to pool: %w", err) + return + } + case *x509.Certificate: + pool.AddCert(cert) + default: + err = fmt.Errorf( + "invalid certificate type '%T', should be 'string', '[]byte' or '*x509.Certificate'", + cert, + ) + return + } + } + result = pool return } diff --git a/internal/network/cert_pool_test.go b/internal/network/cert_pool_test.go index 3f3903360..a65597d67 100644 --- a/internal/network/cert_pool_test.go +++ b/internal/network/cert_pool_test.go @@ -241,6 +241,131 @@ var _ = Describe("Certificate pool", func() { Expect(err.Error()).To(ContainSubstring("doesn't contain any CA certificate")) Expect(pool).To(BeNil()) }) + + It("Can be created with a single certificate from bytes", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + caPem, err := os.ReadFile(myCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificate(caPem). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Can be created with a single certificate from a string", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + caPem, err := os.ReadFile(myCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificate(string(caPem)). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Can be created with a single certificate from X.509 object", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificate(myCerts.caCert). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Can be created with multiple certificates from bytes", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + yourCerts := makeCerts("Your CA") + myPem, err := os.ReadFile(myCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + yourPem, err := os.ReadFile(yourCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificates(myPem, yourPem). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Can be created with multiple certificates from strings", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + yourCerts := makeCerts("Your CA") + myPem, err := os.ReadFile(myCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + yourPem, err := os.ReadFile(yourCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificates(string(myPem), string(yourPem)). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Can be created with multiple certificates from X.509 objects", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + yourCerts := makeCerts("Your CA") + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificates(myCerts.caCert, yourCerts.caCert). + Build() + Expect(err).ToNot(HaveOccurred()) + Expect(pool).ToNot(BeNil()) + }) + + It("Fails with invalid byte content", func() { + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificate([]byte("junk")). + Build() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to add certificate")) + Expect(pool).To(BeNil()) + }) + + It("Fails with invalid string content", func() { + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificate("not a valid certificate"). + Build() + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("failed to add certificate")) + Expect(pool).To(BeNil()) + }) + + It("Fails with unsupported type", func() { + pool, err := NewCertPool(). + SetLogger(logger). + AddCertificate(12345). + Build() + Expect(err).To(HaveOccurred()) + Expect(err).To(MatchError( + "invalid certificate type 'int', should be 'string', '[]byte' or '*x509.Certificate'", + )) + Expect(pool).To(BeNil()) + }) }) Describe("Behavior", func() { @@ -449,6 +574,154 @@ var _ = Describe("Certificate pool", func() { Expect(err).ToNot(HaveOccurred()) }) + It("Can verify certificates loaded from a byte slice", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + caPem, err := os.ReadFile(myCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddSystemFiles(false). + AddKubernetesFiles(false). + AddCertificate(caPem). + Build() + Expect(err).ToNot(HaveOccurred()) + + // Verify the TLS certificate using the pool: + opts := x509.VerifyOptions{ + Roots: pool, + } + chains, err := myCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + Expect(chains).ToNot(BeEmpty()) + }) + + It("Can verify certificates loaded from a string", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + caPem, err := os.ReadFile(myCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddSystemFiles(false). + AddKubernetesFiles(false). + AddCertificate(string(caPem)). + Build() + Expect(err).ToNot(HaveOccurred()) + + // Verify the TLS certificate using the pool: + opts := x509.VerifyOptions{ + Roots: pool, + } + chains, err := myCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + Expect(chains).ToNot(BeEmpty()) + }) + + It("Can verify certificates loaded from an X.509 object", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddSystemFiles(false). + AddKubernetesFiles(false). + AddCertificate(myCerts.caCert). + Build() + Expect(err).ToNot(HaveOccurred()) + + // Verify the TLS certificate using the pool: + opts := x509.VerifyOptions{ + Roots: pool, + } + chains, err := myCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + Expect(chains).ToNot(BeEmpty()) + }) + + It("Can verify certificates from multiple CAs loaded from byte slices", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + yourCerts := makeCerts("Your CA") + myPem, err := os.ReadFile(myCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + yourPem, err := os.ReadFile(yourCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddSystemFiles(false). + AddKubernetesFiles(false). + AddCertificates(myPem, yourPem). + Build() + Expect(err).ToNot(HaveOccurred()) + + // Verify the TLS certificates using the pool: + opts := x509.VerifyOptions{ + Roots: pool, + } + _, err = myCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + _, err = yourCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + }) + + It("Can verify certificates from multiple CAs loaded from X.509 objects", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + yourCerts := makeCerts("Your CA") + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddSystemFiles(false). + AddKubernetesFiles(false). + AddCertificates(myCerts.caCert, yourCerts.caCert). + Build() + Expect(err).ToNot(HaveOccurred()) + + // Verify the TLS certificates using the pool: + opts := x509.VerifyOptions{ + Roots: pool, + } + _, err = myCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + _, err = yourCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + }) + + It("Can mix byte slices and X.509 objects in the same pool", func() { + // Create the certificates: + myCerts := makeCerts("My CA") + yourCerts := makeCerts("Your CA") + myPem, err := os.ReadFile(myCerts.caCertFile) + Expect(err).ToNot(HaveOccurred()) + + // Create the pool: + pool, err := NewCertPool(). + SetLogger(logger). + AddSystemFiles(false). + AddKubernetesFiles(false). + AddCertificates(myPem, yourCerts.caCert). + Build() + Expect(err).ToNot(HaveOccurred()) + + // Verify the TLS certificates using the pool: + opts := x509.VerifyOptions{ + Roots: pool, + } + _, err = myCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + _, err = yourCerts.tlsCert.Verify(opts) + Expect(err).ToNot(HaveOccurred()) + }) + It("Doesn't load system certificates if explicitly disabled", func() { // Create the pool: pool, err := NewCertPool().