diff --git a/README.md b/README.md index 4167010..69d7736 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ environment variables that you can set. | Variable Name | Description | Default Value | |-----------------------------|---------------------------------------------------------|---------------| | `TLS_DOMAIN` | Comma-separated list of domain names to use for TLS provisioning. If not set, TLS will be disabled. | None | +| `TLS_LOCAL` | Whether to use a self-signed certificate authority for TLS certificate provisioning. | Disabled | | `TARGET_PORT` | The port that your Puma server should run on. Thruster will set `PORT` to this value when starting your server. | 3000 | | `CACHE_SIZE` | The size of the HTTP cache in bytes. | 64MB | | `MAX_CACHE_ITEM_SIZE` | The maximum size of a single item in the HTTP cache in bytes. | 1MB | diff --git a/internal/autocert_tls_provider.go b/internal/autocert_tls_provider.go new file mode 100644 index 0000000..2993229 --- /dev/null +++ b/internal/autocert_tls_provider.go @@ -0,0 +1,59 @@ +package internal + +import ( + "crypto/tls" + "encoding/base64" + "log/slog" + "net/http" + + "golang.org/x/crypto/acme" + "golang.org/x/crypto/acme/autocert" +) + +type AutocertTLSProvider struct { + manager *autocert.Manager +} + +func NewAutocertTLSProvider(storagePath string, domains []string, acmeDirectoryURL string, eabKID string, eabHMACKey string) TLSProvider { + client := &acme.Client{DirectoryURL: acmeDirectoryURL} + binding := createExternalAccountBinding(eabKID, eabHMACKey) + + slog.Debug("TLS: initializing autocert", "directory", client.DirectoryURL, "using_eab", binding != nil) + + manager := &autocert.Manager{ + Cache: autocert.DirCache(storagePath), + Client: client, + ExternalAccountBinding: binding, + HostPolicy: autocert.HostWhitelist(domains...), + Prompt: autocert.AcceptTOS, + } + + return &AutocertTLSProvider{ + manager: manager, + } +} + +func (p *AutocertTLSProvider) HTTPHandler(h http.Handler) http.Handler { + return p.manager.HTTPHandler(h) +} + +func (p *AutocertTLSProvider) TLSConfig() *tls.Config { + return p.manager.TLSConfig() +} + +func createExternalAccountBinding(kid string, hmacKey string) *acme.ExternalAccountBinding { + if kid == "" || hmacKey == "" { + return nil + } + + key, err := base64.RawURLEncoding.DecodeString(hmacKey) + if err != nil { + slog.Error("Error decoding EAB_HMACKey", "error", err) + return nil + } + + return &acme.ExternalAccountBinding{ + KID: kid, + Key: key, + } +} diff --git a/internal/autocert_tls_provider_test.go b/internal/autocert_tls_provider_test.go new file mode 100644 index 0000000..8ead916 --- /dev/null +++ b/internal/autocert_tls_provider_test.go @@ -0,0 +1,97 @@ +package internal + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewAutocertTLSProvider(t *testing.T) { + tmpDir := t.TempDir() + domains := []string{"example.com", "www.example.com"} + acmeURL := "https://acme-staging-v02.api.letsencrypt.org/directory" + + provider := NewAutocertTLSProvider(tmpDir, domains, acmeURL, "", "") + + require.NotNil(t, provider) +} + +func TestNewAutocertTLSProvider_WithEAB(t *testing.T) { + tmpDir := t.TempDir() + domains := []string{"example.com"} + acmeURL := "https://acme.zerossl.com/v2/DV90" + eabKID := "test-kid" + eabHMACKey := "dGVzdC1obWFjLWtleQ" // base64 encoded "test-hmac-key" + + provider := NewAutocertTLSProvider(tmpDir, domains, acmeURL, eabKID, eabHMACKey) + + require.NotNil(t, provider) +} + +func TestAutocertTLSProvider_HTTPHandler(t *testing.T) { + tmpDir := t.TempDir() + domains := []string{"example.com"} + acmeURL := "https://acme-staging-v02.api.letsencrypt.org/directory" + provider := NewAutocertTLSProvider(tmpDir, domains, acmeURL, "", "") + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("test")) + require.NoError(t, err) + }) + + wrapped := provider.HTTPHandler(handler) + require.NotNil(t, wrapped) +} + +func TestAutocertTLSProvider_TLSConfig(t *testing.T) { + tmpDir := t.TempDir() + domains := []string{"example.com", "www.example.com"} + acmeURL := "https://acme-staging-v02.api.letsencrypt.org/directory" + provider := NewAutocertTLSProvider(tmpDir, domains, acmeURL, "", "") + + config := provider.TLSConfig() + + require.NotNil(t, config) + assert.NotNil(t, config.GetCertificate) + assert.Contains(t, config.NextProtos, "h2") + assert.Contains(t, config.NextProtos, "http/1.1") + assert.Contains(t, config.NextProtos, "acme-tls/1") +} + +func TestCreateExternalAccountBinding_ValidBase64(t *testing.T) { + kid := "test-kid" + hmacKey := "dGVzdC1obWFjLWtleQ" // base64 encoded "test-hmac-key" + + binding := createExternalAccountBinding(kid, hmacKey) + + require.NotNil(t, binding) + assert.Equal(t, kid, binding.KID) + assert.Equal(t, []byte("test-hmac-key"), binding.Key) +} + +func TestCreateExternalAccountBinding_InvalidBase64(t *testing.T) { + kid := "test-kid" + hmacKey := "not-valid-base64!!!" // Invalid base64 + + binding := createExternalAccountBinding(kid, hmacKey) + + // Should return nil on invalid base64 + assert.Nil(t, binding) +} + +func TestCreateExternalAccountBinding_EmptyInputs(t *testing.T) { + // Both empty + binding := createExternalAccountBinding("", "") + assert.Nil(t, binding) + + // Only KID + binding = createExternalAccountBinding("test-kid", "") + assert.Nil(t, binding) + + // Only HMAC key + binding = createExternalAccountBinding("", "dGVzdC1obWFjLWtleQ") + assert.Nil(t, binding) +} diff --git a/internal/config.go b/internal/config.go index fd2887d..cfc191a 100644 --- a/internal/config.go +++ b/internal/config.go @@ -56,6 +56,7 @@ type Config struct { MaxRequestBody int TLSDomains []string + TLSLocal bool ACMEDirectoryURL string EAB_KID string EAB_HMACKey string @@ -100,6 +101,7 @@ func NewConfig() (*Config, error) { MaxRequestBody: getEnvInt("MAX_REQUEST_BODY", defaultMaxRequestBody), TLSDomains: getEnvStrings("TLS_DOMAIN", []string{}), + TLSLocal: getEnvBool("TLS_LOCAL", false), ACMEDirectoryURL: getEnvString("ACME_DIRECTORY", defaultACMEDirectoryURL), EAB_KID: getEnvString("EAB_KID", ""), EAB_HMACKey: getEnvString("EAB_HMAC_KEY", ""), @@ -124,7 +126,7 @@ func NewConfig() (*Config, error) { } func (c *Config) HasTLS() bool { - return len(c.TLSDomains) > 0 + return len(c.TLSDomains) > 0 || c.TLSLocal } func findEnv(key string) (string, bool) { diff --git a/internal/config_test.go b/internal/config_test.go index 5dc1d9b..f718e3a 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -93,6 +93,20 @@ func TestConfig_tls(t *testing.T) { assert.False(t, c.HasTLS()) assert.False(t, c.ForwardHeaders) }) + + t.Run("with TLS_LOCAL", func(t *testing.T) { + usingProgramArgs(t, "thruster", "echo", "hello") + usingEnvVar(t, "TLS_DOMAIN", "") + usingEnvVar(t, "TLS_LOCAL", "true") + + c, err := NewConfig() + require.NoError(t, err) + + assert.Equal(t, []string{}, c.TLSDomains) + assert.True(t, c.HasTLS()) + assert.False(t, c.ForwardHeaders) + assert.True(t, c.TLSLocal) + }) } func TestConfig_defaults(t *testing.T) { @@ -106,6 +120,7 @@ func TestConfig_defaults(t *testing.T) { assert.Equal(t, defaultCacheSize, c.CacheSizeBytes) assert.Equal(t, slog.LevelInfo, c.LogLevel) assert.Equal(t, false, c.H2CEnabled) + assert.Equal(t, false, c.TLSLocal) } func TestConfig_override_defaults_with_env_vars(t *testing.T) { @@ -119,6 +134,7 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { usingEnvVar(t, "ACME_DIRECTORY", "https://acme-staging-v02.api.letsencrypt.org/directory") usingEnvVar(t, "LOG_REQUESTS", "false") usingEnvVar(t, "H2C_ENABLED", "true") + usingEnvVar(t, "TLS_LOCAL", "true") usingEnvVar(t, "GZIP_COMPRESSION_DISABLE_ON_AUTH", "true") usingEnvVar(t, "GZIP_COMPRESSION_JITTER", "64") @@ -134,6 +150,7 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { assert.Equal(t, "https://acme-staging-v02.api.letsencrypt.org/directory", c.ACMEDirectoryURL) assert.Equal(t, false, c.LogRequests) assert.Equal(t, true, c.H2CEnabled) + assert.Equal(t, true, c.TLSLocal) assert.Equal(t, true, c.GzipCompressionDisableOnAuth) assert.Equal(t, 64, c.GzipCompressionJitter) } @@ -147,6 +164,7 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { usingEnvVar(t, "THRUSTER_DEBUG", "1") usingEnvVar(t, "THRUSTER_LOG_REQUESTS", "0") usingEnvVar(t, "THRUSTER_H2C_ENABLED", "1") + usingEnvVar(t, "THRUSTER_TLS_LOCAL", "1") c, err := NewConfig() require.NoError(t, err) @@ -158,6 +176,7 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { assert.Equal(t, slog.LevelDebug, c.LogLevel) assert.Equal(t, false, c.LogRequests) assert.Equal(t, true, c.H2CEnabled) + assert.Equal(t, true, c.TLSLocal) } func TestConfig_prefixed_variables_take_precedence_over_non_prefixed(t *testing.T) { diff --git a/internal/local_tls_provider.go b/internal/local_tls_provider.go new file mode 100644 index 0000000..35d6bc3 --- /dev/null +++ b/internal/local_tls_provider.go @@ -0,0 +1,192 @@ +package internal + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "log/slog" + "math/big" + "net" + "net/http" + "os" + "time" + + "golang.org/x/net/idna" +) + +type localTLSProvider struct { + storagePath string +} + +func NewLocalTLSProvider(storagePath string) TLSProvider { + return &localTLSProvider{ + storagePath: storagePath, + } +} + +func (p *localTLSProvider) HTTPHandler(h http.Handler) http.Handler { + return h +} + +func (p *localTLSProvider) TLSConfig() *tls.Config { + return &tls.Config{ + GetCertificate: p.getCertificate, + NextProtos: []string{ + "h2", "http/1.1", // enable HTTP/2 + }, + } +} + +func (p *localTLSProvider) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + name := hello.ServerName + if name == "" { + return nil, errors.New("thruster/local_tls: missing server name") + } + + name, err := idna.Lookup.ToASCII(name) + if err != nil { + return nil, errors.New("thruster/local_tls: server name contains invalid character") + } + + keyUsage := x509.KeyUsageDigitalSignature + keyUsage |= x509.KeyUsageKeyEncipherment + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Thruster Local"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 10 * 24 * time.Hour), + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + if ip := net.ParseIP(name); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, name) + } + + authority, err := p.getAuthority() + if err != nil { + return nil, err + } + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + authcert, err := x509.ParseCertificate(authority.Certificate[0]) + if err != nil { + return nil, err + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, authcert, &priv.PublicKey, authority.PrivateKey) + if err != nil { + return nil, err + } + + cert := &tls.Certificate{ + Certificate: [][]byte{authority.Certificate[0], derBytes}, + PrivateKey: authority.PrivateKey, + } + + slog.Debug("TLS: issued local certificate for", "name", name) + + return cert, nil +} + +func (p *localTLSProvider) getAuthority() (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(fmt.Sprintf("%s/authority.crt", p.storagePath), fmt.Sprintf("%s/authority.pem", p.storagePath)) + if err == nil { + return &cert, nil + } + + err = os.MkdirAll(p.storagePath, 0750) + if err != nil { + return nil, err + } + + keyUsage := x509.KeyUsageDigitalSignature + keyUsage |= x509.KeyUsageKeyEncipherment + keyUsage |= x509.KeyUsageCertSign + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return nil, err + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Thruster Local CA"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 10 * 24 * time.Hour), + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + IsCA: true, + } + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, err + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return nil, err + } + + certOut, err := os.Create(fmt.Sprintf("%s/authority.crt", p.storagePath)) + if err != nil { + return nil, err + } + + if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + return nil, err + } + + if err := certOut.Close(); err != nil { + return nil, err + } + + keyOut, err := os.Create(fmt.Sprintf("%s/authority.pem", p.storagePath)) + if err != nil { + return nil, err + } + + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + return nil, err + } + + if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + return nil, err + } + + if err := keyOut.Close(); err != nil { + return nil, err + } + + cer := tls.Certificate{ + Certificate: [][]byte{derBytes}, + PrivateKey: priv, + } + + return &cer, nil +} diff --git a/internal/local_tls_provider_test.go b/internal/local_tls_provider_test.go new file mode 100644 index 0000000..390fd90 --- /dev/null +++ b/internal/local_tls_provider_test.go @@ -0,0 +1,242 @@ +package internal + +import ( + "crypto/tls" + "crypto/x509" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestLocalTLSProvider_HTTPHandler(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("test response")) + require.NoError(t, err) + }) + + wrapped := provider.HTTPHandler(handler) + + // HTTPHandler should just pass through without modification for local TLS + req := httptest.NewRequest("GET", "http://example.com/", nil) + rec := httptest.NewRecorder() + wrapped.ServeHTTP(rec, req) + + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "test response", rec.Body.String()) +} + +func TestLocalTLSProvider_TLSConfig(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + config := provider.TLSConfig() + + require.NotNil(t, config) + assert.NotNil(t, config.GetCertificate) + assert.Contains(t, config.NextProtos, "h2") + assert.Contains(t, config.NextProtos, "http/1.1") +} + +func TestLocalTLSProvider_WithDomainName(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + hello := &tls.ClientHelloInfo{ + ServerName: "example.com", + } + + config := provider.TLSConfig() + cert, err := config.GetCertificate(hello) + + require.NoError(t, err) + require.NotNil(t, cert) + assert.NotNil(t, cert.PrivateKey) + assert.NotEmpty(t, cert.Certificate) + + // Verify the certificate has the correct DNS name + x509Cert, err := x509.ParseCertificate(cert.Certificate[1]) + require.NoError(t, err) + assert.Contains(t, x509Cert.DNSNames, "example.com") + assert.Equal(t, "Thruster Local", x509Cert.Subject.Organization[0]) +} + +func TestLocalTLSProvider_WithIPAddress(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + hello := &tls.ClientHelloInfo{ + ServerName: "127.0.0.1", + } + + config := provider.TLSConfig() + cert, err := config.GetCertificate(hello) + + require.NoError(t, err) + require.NotNil(t, cert) + + // Verify the certificate has the correct IP address + x509Cert, err := x509.ParseCertificate(cert.Certificate[1]) + require.NoError(t, err) + assert.Len(t, x509Cert.IPAddresses, 1) + assert.Equal(t, "127.0.0.1", x509Cert.IPAddresses[0].String()) +} + +func TestLocalTLSProvider_MissingServerName(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + hello := &tls.ClientHelloInfo{ + ServerName: "", + } + + config := provider.TLSConfig() + cert, err := config.GetCertificate(hello) + + assert.Error(t, err) + assert.Nil(t, cert) + assert.Contains(t, err.Error(), "missing server name") +} + +func TestLocalTLSProvider_InvalidServerName(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + hello := &tls.ClientHelloInfo{ + ServerName: "invalid\x00name", + } + + config := provider.TLSConfig() + cert, err := config.GetCertificate(hello) + + assert.Error(t, err) + assert.Nil(t, cert) +} + +func TestLocalTLSProvider_CreatesNewCA(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + hello := &tls.ClientHelloInfo{ + ServerName: "example.com", + } + + config := provider.TLSConfig() + cert, err := config.GetCertificate(hello) + + require.NoError(t, err) + require.NotNil(t, cert) + + // Verify CA files were created + assert.FileExists(t, filepath.Join(tmpDir, "authority.crt")) + assert.FileExists(t, filepath.Join(tmpDir, "authority.pem")) + + // Verify the CA certificate properties + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.NotNil(t, x509Cert) + + require.NoError(t, err) + assert.True(t, x509Cert.IsCA) + assert.Equal(t, "Thruster Local CA", x509Cert.Subject.Organization[0]) +} + +func TestLocalTLSProvider_ReuseCA(t *testing.T) { + tmpDir := t.TempDir() + + hello := &tls.ClientHelloInfo{ + ServerName: "example.com", + } + + // This should create a new CA + provider1 := NewLocalTLSProvider(tmpDir) + config1 := provider1.TLSConfig() + cert1, err := config1.GetCertificate(hello) + require.NoError(t, err) + require.NotNil(t, cert1) + + // This should reuse the previous CA + provider2 := NewLocalTLSProvider(tmpDir) + config2 := provider2.TLSConfig() + cert2, err := config2.GetCertificate(hello) + require.NoError(t, err) + require.NotNil(t, cert2) + + // Should be the same certificate (index 0 is the CA) + assert.Equal(t, cert1.Certificate[0], cert2.Certificate[0]) +} + +func TestLocalTLSProvider_EndToEnd(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + // Create a test server with the TLS config + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("Hello, TLS!")) + require.NoError(t, err) + }) + + server := httptest.NewUnstartedServer(handler) + server.TLS = provider.TLSConfig() + server.StartTLS() + defer server.Close() + + // Create a client that accepts our self-signed certificate + client := server.Client() + + // Make a request + resp, err := client.Get(server.URL) + require.NoError(t, err) + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestLocalTLSProvider_CreatesStorageDirectory(t *testing.T) { + tmpDir := t.TempDir() + storageDir := filepath.Join(tmpDir, "nested", "storage", "path") + provider := NewLocalTLSProvider(storageDir) + + config := provider.TLSConfig() + cert, err := config.GetCertificate(&tls.ClientHelloInfo{ServerName: "example.com"}) + + require.NoError(t, err) + require.NotNil(t, cert) + + // Verify the nested directory was created + info, err := os.Stat(storageDir) + require.NoError(t, err) + assert.True(t, info.IsDir()) +} + +func TestLocalTLSProvider_WithInternationalDomain(t *testing.T) { + tmpDir := t.TempDir() + provider := NewLocalTLSProvider(tmpDir) + + hello := &tls.ClientHelloInfo{ + ServerName: "zürich.example.com", + } + + config := provider.TLSConfig() + cert, err := config.GetCertificate(hello) + + require.NoError(t, err) + require.NotNil(t, cert) + + // Verify the certificate has the punycode-encoded domain + x509Cert, err := x509.ParseCertificate(cert.Certificate[1]) + require.NoError(t, err) + assert.Contains(t, x509Cert.DNSNames, "xn--zrich-kva.example.com") +} diff --git a/internal/server.go b/internal/server.go index d0f8cd4..98e234a 100644 --- a/internal/server.go +++ b/internal/server.go @@ -2,15 +2,11 @@ package internal import ( "context" - "encoding/base64" "fmt" "log/slog" "net" "net/http" "time" - - "golang.org/x/crypto/acme" - "golang.org/x/crypto/acme/autocert" ) type Server struct { @@ -31,14 +27,13 @@ func (s *Server) Start() error { httpAddress := fmt.Sprintf(":%d", s.config.HttpPort) httpsAddress := fmt.Sprintf(":%d", s.config.HttpsPort) - if s.config.HasTLS() { - manager := s.certManager() - + tlsProvider := s.tlsProvider() + if tlsProvider != nil { s.httpServer = s.defaultHttpServer(httpAddress) - s.httpServer.Handler = manager.HTTPHandler(http.HandlerFunc(httpRedirectHandler)) + s.httpServer.Handler = tlsProvider.HTTPHandler(http.HandlerFunc(httpRedirectHandler)) s.httpsServer = s.defaultHttpServer(httpsAddress) - s.httpsServer.TLSConfig = manager.TLSConfig() + s.httpsServer.TLSConfig = tlsProvider.TLSConfig() s.httpsServer.Handler = s.handler httpListener, err := net.Listen("tcp", httpAddress) @@ -89,36 +84,20 @@ func (s *Server) Stop() { } } -func (s *Server) certManager() *autocert.Manager { - client := &acme.Client{DirectoryURL: s.config.ACMEDirectoryURL} - binding := s.externalAccountBinding() - - slog.Debug("TLS: initializing", "directory", client.DirectoryURL, "using_eab", binding != nil) - - return &autocert.Manager{ - Cache: autocert.DirCache(s.config.StoragePath), - Client: client, - ExternalAccountBinding: binding, - HostPolicy: autocert.HostWhitelist(s.config.TLSDomains...), - Prompt: autocert.AcceptTOS, - } -} - -func (s *Server) externalAccountBinding() *acme.ExternalAccountBinding { - if s.config.EAB_KID == "" || s.config.EAB_HMACKey == "" { - return nil - } - - key, err := base64.RawURLEncoding.DecodeString(s.config.EAB_HMACKey) - if err != nil { - slog.Error("Error decoding EAB_HMACKey", "error", err) +func (s *Server) tlsProvider() TLSProvider { + if !s.config.HasTLS() { return nil } - - return &acme.ExternalAccountBinding{ - KID: s.config.EAB_KID, - Key: key, + if s.config.TLSLocal { + return NewLocalTLSProvider(s.config.StoragePath) } + return NewAutocertTLSProvider( + s.config.StoragePath, + s.config.TLSDomains, + s.config.ACMEDirectoryURL, + s.config.EAB_KID, + s.config.EAB_HMACKey, + ) } func (s *Server) defaultHttpServer(addr string) *http.Server { diff --git a/internal/server_test.go b/internal/server_test.go index 4b41a35..9783b8f 100644 --- a/internal/server_test.go +++ b/internal/server_test.go @@ -57,6 +57,33 @@ func TestServerDefaultCannotMakeH2CRequest(t *testing.T) { assert.Contains(t, err.Error(), "http2: failed reading the frame payload") } +func TestServerDefaultHasNoTLSProvider(t *testing.T) { + config, err := NewConfig() + require.NoError(t, err) + + server := NewServer(config, nil) + assert.Nil(t, server.tlsProvider()) +} + +func TestServerWithAutocertTLSProvider(t *testing.T) { + config, err := NewConfig() + require.NoError(t, err) + + config.TLSDomains = []string{"example.com"} + server := NewServer(config, nil) + assert.NotNil(t, server.tlsProvider()) +} + +func TestServerWithLocalTLSProvider(t *testing.T) { + config, err := NewConfig() + require.NoError(t, err) + + config.TLSDomains = []string{"example.com"} + config.TLSLocal = true + server := NewServer(config, nil) + assert.NotNil(t, server.tlsProvider()) +} + func makeRoundTripH2cRequest(t *testing.T, h2cEnabled bool) (*http.Response, error) { upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, "HTTP/1.1", r.Proto, "The upstream should still be serving http/1.1") diff --git a/internal/tls_provider.go b/internal/tls_provider.go new file mode 100644 index 0000000..2a7e0c5 --- /dev/null +++ b/internal/tls_provider.go @@ -0,0 +1,11 @@ +package internal + +import ( + "crypto/tls" + "net/http" +) + +type TLSProvider interface { + HTTPHandler(h http.Handler) http.Handler + TLSConfig() *tls.Config +}