diff --git a/README.md b/README.md index 4167010..885ed7b 100644 --- a/README.md +++ b/README.md @@ -87,6 +87,11 @@ environment variables that you can set. | `BAD_GATEWAY_PAGE` | Path to an HTML file to serve when the backend server returns a 502 Bad Gateway error. If there is no file at the specific path, Thruster will serve an empty 502 response instead. Because Thruster boots very quickly, a custom page can be a useful way to show that your application is starting up. | `./public/502.html` | | `HTTP_PORT` | The port to listen on for HTTP traffic. | 80 | | `HTTPS_PORT` | The port to listen on for HTTPS traffic. | 443 | +| `HTTP_HEALTH_PATH` | Path of an HTTP health check endpoint on the upstream server. When set, Thruster polls this path and does not start listening until it returns a 200 response. When unset, no health check is performed. | None | +| `HTTP_HEALTH_HOST` | The host to send health check requests to, on `TARGET_PORT`. | 127.0.0.1 | +| `HTTP_HEALTH_INTERVAL` | The time in seconds between health check attempts. | 1 | +| `HTTP_HEALTH_TIMEOUT` | The maximum time in seconds to wait for a response to each health check request. | 1 | +| `HTTP_HEALTH_DEADLINE` | The maximum total time in seconds to wait for a successful health check. If the upstream is still not healthy after this time, Thruster exits with an error. | 120 | | `HTTP_IDLE_TIMEOUT` | The maximum time in seconds that a client can be idle before the connection is closed. | 60 | | `HTTP_READ_TIMEOUT` | The maximum time in seconds that a client can take to send the request headers and body. | 30 | | `HTTP_WRITE_TIMEOUT` | The maximum time in seconds during which the client must read the response. 
| 30 | diff --git a/internal/config.go b/internal/config.go index fd2887d..55d62e5 100644 --- a/internal/config.go +++ b/internal/config.go @@ -27,11 +27,15 @@ const ( defaultStoragePath = "./storage/thruster" defaultBadGatewayPage = "./public/502.html" - defaultHttpPort = 80 - defaultHttpsPort = 443 - defaultHttpIdleTimeout = 60 * time.Second - defaultHttpReadTimeout = 30 * time.Second - defaultHttpWriteTimeout = 30 * time.Second + defaultHttpPort = 80 + defaultHttpsPort = 443 + defaultHttpHealthHost = "127.0.0.1" + defaultHttpHealthTimeout = 1 * time.Second + defaultHttpHealthInterval = 1 * time.Second + defaultHttpHealthDeadline = 2 * time.Minute + defaultHttpIdleTimeout = 60 * time.Second + defaultHttpReadTimeout = 30 * time.Second + defaultHttpWriteTimeout = 30 * time.Second defaultH2CEnabled = false @@ -62,11 +66,16 @@ type Config struct { StoragePath string BadGatewayPage string - HttpPort int - HttpsPort int - HttpIdleTimeout time.Duration - HttpReadTimeout time.Duration - HttpWriteTimeout time.Duration + HttpPort int + HttpsPort int + HttpHealthHost string + HttpHealthPath string + HttpHealthTimeout time.Duration + HttpHealthInterval time.Duration + HttpHealthDeadline time.Duration + HttpIdleTimeout time.Duration + HttpReadTimeout time.Duration + HttpWriteTimeout time.Duration H2CEnabled bool @@ -89,7 +98,7 @@ func NewConfig() (*Config, error) { config := &Config{ TargetPort: getEnvInt("TARGET_PORT", defaultTargetPort), UpstreamCommand: os.Args[1], - UpstreamArgs: os.Args[2:], + UpstreamArgs: append([]string{}, os.Args[2:]...), CacheSizeBytes: getEnvInt("CACHE_SIZE", defaultCacheSize), MaxCacheItemSizeBytes: getEnvInt("MAX_CACHE_ITEM_SIZE", defaultMaxCacheItemSizeBytes), @@ -106,11 +115,16 @@ func NewConfig() (*Config, error) { StoragePath: getEnvString("STORAGE_PATH", defaultStoragePath), BadGatewayPage: getEnvString("BAD_GATEWAY_PAGE", defaultBadGatewayPage), - HttpPort: getEnvInt("HTTP_PORT", defaultHttpPort), - HttpsPort: getEnvInt("HTTPS_PORT", 
defaultHttpsPort), - HttpIdleTimeout: getEnvDuration("HTTP_IDLE_TIMEOUT", defaultHttpIdleTimeout), - HttpReadTimeout: getEnvDuration("HTTP_READ_TIMEOUT", defaultHttpReadTimeout), - HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout), + HttpPort: getEnvInt("HTTP_PORT", defaultHttpPort), + HttpsPort: getEnvInt("HTTPS_PORT", defaultHttpsPort), + HttpHealthHost: getEnvString("HTTP_HEALTH_HOST", defaultHttpHealthHost), + HttpHealthPath: getEnvString("HTTP_HEALTH_PATH", ""), + HttpHealthInterval: getEnvDuration("HTTP_HEALTH_INTERVAL", defaultHttpHealthInterval), + HttpHealthTimeout: getEnvDuration("HTTP_HEALTH_TIMEOUT", defaultHttpHealthTimeout), + HttpHealthDeadline: getEnvDuration("HTTP_HEALTH_DEADLINE", defaultHttpHealthDeadline), + HttpIdleTimeout: getEnvDuration("HTTP_IDLE_TIMEOUT", defaultHttpIdleTimeout), + HttpReadTimeout: getEnvDuration("HTTP_READ_TIMEOUT", defaultHttpReadTimeout), + HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout), H2CEnabled: getEnvBool("H2C_ENABLED", defaultH2CEnabled), diff --git a/internal/config_test.go b/internal/config_test.go index 5dc1d9b..f8c6100 100644 --- a/internal/config_test.go +++ b/internal/config_test.go @@ -105,6 +105,11 @@ func TestConfig_defaults(t *testing.T) { assert.Equal(t, "echo", c.UpstreamCommand) assert.Equal(t, defaultCacheSize, c.CacheSizeBytes) assert.Equal(t, slog.LevelInfo, c.LogLevel) + assert.Equal(t, "", c.HttpHealthPath) + assert.Equal(t, "127.0.0.1", c.HttpHealthHost) + assert.Equal(t, 1*time.Second, c.HttpHealthTimeout) + assert.Equal(t, 1*time.Second, c.HttpHealthInterval) + assert.Equal(t, 2*time.Minute, c.HttpHealthDeadline) assert.Equal(t, false, c.H2CEnabled) } @@ -118,6 +123,11 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { usingEnvVar(t, "DEBUG", "1") usingEnvVar(t, "ACME_DIRECTORY", "https://acme-staging-v02.api.letsencrypt.org/directory") usingEnvVar(t, "LOG_REQUESTS", "false") + usingEnvVar(t, "HTTP_HEALTH_PATH", 
"/health") + usingEnvVar(t, "HTTP_HEALTH_HOST", "localhost") + usingEnvVar(t, "HTTP_HEALTH_INTERVAL", "3") + usingEnvVar(t, "HTTP_HEALTH_TIMEOUT", "4") + usingEnvVar(t, "HTTP_HEALTH_DEADLINE", "60") usingEnvVar(t, "H2C_ENABLED", "true") usingEnvVar(t, "GZIP_COMPRESSION_DISABLE_ON_AUTH", "true") usingEnvVar(t, "GZIP_COMPRESSION_JITTER", "64") @@ -132,6 +142,11 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) { assert.Equal(t, false, c.GzipCompressionEnabled) assert.Equal(t, slog.LevelDebug, c.LogLevel) assert.Equal(t, "https://acme-staging-v02.api.letsencrypt.org/directory", c.ACMEDirectoryURL) + assert.Equal(t, "/health", c.HttpHealthPath) + assert.Equal(t, "localhost", c.HttpHealthHost) + assert.Equal(t, 3*time.Second, c.HttpHealthInterval) + assert.Equal(t, 4*time.Second, c.HttpHealthTimeout) + assert.Equal(t, 60*time.Second, c.HttpHealthDeadline) assert.Equal(t, false, c.LogRequests) assert.Equal(t, true, c.H2CEnabled) assert.Equal(t, true, c.GzipCompressionDisableOnAuth) @@ -146,6 +161,11 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { usingEnvVar(t, "THRUSTER_X_SENDFILE_ENABLED", "0") usingEnvVar(t, "THRUSTER_DEBUG", "1") usingEnvVar(t, "THRUSTER_LOG_REQUESTS", "0") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_PATH", "/health") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_HOST", "localhost") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_INTERVAL", "3") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_TIMEOUT", "4") + usingEnvVar(t, "THRUSTER_HTTP_HEALTH_DEADLINE", "60") usingEnvVar(t, "THRUSTER_H2C_ENABLED", "1") c, err := NewConfig() @@ -157,6 +177,11 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) { assert.Equal(t, false, c.XSendfileEnabled) assert.Equal(t, slog.LevelDebug, c.LogLevel) assert.Equal(t, false, c.LogRequests) + assert.Equal(t, "/health", c.HttpHealthPath) + assert.Equal(t, "localhost", c.HttpHealthHost) + assert.Equal(t, 3*time.Second, c.HttpHealthInterval) + assert.Equal(t, 4*time.Second, 
c.HttpHealthTimeout) + assert.Equal(t, 60*time.Second, c.HttpHealthDeadline) assert.Equal(t, true, c.H2CEnabled) } @@ -171,6 +196,20 @@ func TestConfig_prefixed_variables_take_precedence_over_non_prefixed(t *testing. assert.Equal(t, 4000, c.TargetPort) } +func TestConfig_defaults_are_used_if_strconv_fails(t *testing.T) { + usingProgramArgs(t, "thruster", "echo", "hello") + usingEnvVar(t, "TARGET_PORT", "should-be-an-int") + usingEnvVar(t, "HTTP_IDLE_TIMEOUT", "should-be-a-duration") + usingEnvVar(t, "X_SENDFILE_ENABLED", "should-be-a-bool") + + c, err := NewConfig() + require.NoError(t, err) + + assert.Equal(t, 3000, c.TargetPort) + assert.Equal(t, 60*time.Second, c.HttpIdleTimeout) + assert.Equal(t, true, c.XSendfileEnabled) +} + func TestConfig_return_error_when_no_upstream_command(t *testing.T) { usingProgramArgs(t, "thruster") diff --git a/internal/service.go b/internal/service.go index c34ba73..8e6fb4c 100644 --- a/internal/service.go +++ b/internal/service.go @@ -1,16 +1,27 @@ package internal import ( + "context" "fmt" "log/slog" + "net/http" "net/url" "os" + "os/signal" + "syscall" + "time" ) type Service struct { config *Config } +// Represents the result of the upstream process execution. +type upstreamResult struct { + exitCode int + err error +} + func NewService(config *Config) *Service { return &Service{ config: config, @@ -36,23 +47,136 @@ func (s *Service) Run() int { server := NewServer(s.config, handler) upstream := NewUpstreamProcess(s.config.UpstreamCommand, s.config.UpstreamArgs...) + s.setEnvironment() + + // Channel to receive the result from the upstream process goroutine. 
+ resultChan := make(chan upstreamResult, 1) + + // Run the upstream process in a separate goroutine + // This allows us to perform health checks while it starts up + go func() { + exitCode, err := upstream.Run() + resultChan <- upstreamResult{exitCode: exitCode, err: err} + }() + + // If a health check path is configured, wait for the upstream to become healthy + if s.config.HttpHealthPath != "" { + if err := s.performHealthCheck(resultChan); err != nil { + slog.Error("Upstream health check failed", "error", err) + // At this point, the upstream process is running but unhealthy + if err := upstream.Signal(syscall.SIGTERM); err != nil { + slog.Error("Failed to send signal to upstream process", "error", err) + } + return 1 + } + slog.Info("Upstream service is healthy, starting proxy server.") + } + + // Now that the upstream is ready, start the main proxy server if err := server.Start(); err != nil { return 1 } defer server.Stop() - s.setEnvironment() + // Delegate the waiting and signal handling to the new function + return s.awaitTermination(upstream, resultChan) +} + +// Private + +func (s *Service) awaitTermination(upstream *UpstreamProcess, resultChan <-chan upstreamResult) int { + signalChan := make(chan os.Signal, 1) + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) + + select { + case result := <-resultChan: + // The upstream process finished on its own. + slog.Info("Wrapped process exited on its own.", "exit_code", result.exitCode) + if result.err != nil { + slog.Error("Wrapped process failed", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", result.err) + return 1 + } + return result.exitCode + + case sig := <-signalChan: + // An OS signal was caught + slog.Info("Received signal, shutting down.", "signal", sig.String()) + + // Relay the signal to the child process to allow for graceful shutdown. 
+ slog.Info("Relaying signal to upstream process...") + if err := upstream.Signal(sig); err != nil { + slog.Error("Failed to send signal to upstream process", "error", err) + } - exitCode, err := upstream.Run() - if err != nil { - slog.Error("Failed to start wrapped process", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", err) + // Give the upstream process a moment to shut down gracefully + // before the defer server.Stop() forcefully cleans up. + select { + case <-resultChan: + slog.Info("Upstream process terminated gracefully after signal.") + case <-time.After(10 * time.Second): + slog.Warn("Upstream process did not terminate within 10 seconds of signal.") + } + + // Exit with a non-zero status code to indicate termination by signal. return 1 } - - return exitCode } -// Private +// performHealthCheck polls the health check endpoint until it gets a 200 OK +func (s *Service) performHealthCheck(resultChan <-chan upstreamResult) error { + // Create a context with a 2-minute timeout (default) for the entire health check process + ctx, cancel := context.WithTimeout(context.Background(), s.config.HttpHealthDeadline) + defer cancel() + + // We assume the upstream server binds to the target URL's host + healthCheckURL := fmt.Sprintf("http://%s:%d%s", s.config.HttpHealthHost, s.config.TargetPort, s.config.HttpHealthPath) + slog.Info("Starting health checks", "url", healthCheckURL) + + // Use a ticker to check every second (default) + ticker := time.NewTicker(s.config.HttpHealthInterval) + defer ticker.Stop() + + // Create an HTTP client with a short timeout for individual requests + client := &http.Client{ + Timeout: s.config.HttpHealthTimeout, + } + + for { + select { + case <-ctx.Done(): + // Deadline exceeded + return fmt.Errorf("health check timed out after %v", s.config.HttpHealthDeadline) + + case result := <-resultChan: + // The upstream process exited before it became healthy + return fmt.Errorf("upstream process exited prematurely 
with code %d: %w", result.exitCode, result.err) + + case <-ticker.C: + // Ticker fired, time to perform a check + req, err := http.NewRequestWithContext(ctx, "GET", healthCheckURL, nil) + if err != nil { + return fmt.Errorf("failed to create health check request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + // This is expected while the server is starting up (e.g., "connection refused") + slog.Debug("Health check attempt failed, retrying...", "error", err) + continue + } + + // Don't forget to close the body + resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + // Success! + return nil + } + + slog.Debug("Health check received non-200 status", "status_code", resp.StatusCode) + } + } +} func (s *Service) cache() Cache { return NewMemoryCache(s.config.CacheSizeBytes, s.config.MaxCacheItemSizeBytes)