Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ environment variables that you can set.
| `BAD_GATEWAY_PAGE` | Path to an HTML file to serve when the backend server returns a 502 Bad Gateway error. If there is no file at the specific path, Thruster will serve an empty 502 response instead. Because Thruster boots very quickly, a custom page can be a useful way to show that your application is starting up. | `./public/502.html` |
| `HTTP_PORT` | The port to listen on for HTTP traffic. | 80 |
| `HTTPS_PORT` | The port to listen on for HTTPS traffic. | 443 |
| `HTTP_HEALTH_PATH` | The http health path to check before start port listening. | None |
| `HTTP_HEALTH_HOST` | The http health host to check before start port listening. | 127.0.0.1 |
| `HTTP_HEALTH_INTERVAL` | The http health path check interval (seconds). | 1 |
| `HTTP_HEALTH_TIMEOUT` | The http health path check timeout (seconds). | 1 |
| `HTTP_HEALTH_DEADLINE` | The http health path deadline interval (seconds), after which thruster will exit with error, if no success response. | 120 |
| `HTTP_IDLE_TIMEOUT` | The maximum time in seconds that a client can be idle before the connection is closed. | 60 |
| `HTTP_READ_TIMEOUT` | The maximum time in seconds that a client can take to send the request headers and body. | 30 |
| `HTTP_WRITE_TIMEOUT` | The maximum time in seconds during which the client must read the response. | 30 |
Expand Down
46 changes: 30 additions & 16 deletions internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ const (
defaultStoragePath = "./storage/thruster"
defaultBadGatewayPage = "./public/502.html"

defaultHttpPort = 80
defaultHttpsPort = 443
defaultHttpIdleTimeout = 60 * time.Second
defaultHttpReadTimeout = 30 * time.Second
defaultHttpWriteTimeout = 30 * time.Second
defaultHttpPort = 80
defaultHttpsPort = 443
defaultHttpHealthHost = "127.0.0.1"
defaultHttpHealthTimeout = 1 * time.Second
defaultHttpHealthInterval = 1 * time.Second
defaultHttpHealthDeadline = 2 * time.Minute
defaultHttpIdleTimeout = 60 * time.Second
defaultHttpReadTimeout = 30 * time.Second
defaultHttpWriteTimeout = 30 * time.Second

defaultH2CEnabled = false

Expand Down Expand Up @@ -62,11 +66,16 @@ type Config struct {
StoragePath string
BadGatewayPage string

HttpPort int
HttpsPort int
HttpIdleTimeout time.Duration
HttpReadTimeout time.Duration
HttpWriteTimeout time.Duration
HttpPort int
HttpsPort int
HttpHealthHost string
HttpHealthPath string
HttpHealthTimeout time.Duration
HttpHealthInterval time.Duration
HttpHealthDeadline time.Duration
HttpIdleTimeout time.Duration
HttpReadTimeout time.Duration
HttpWriteTimeout time.Duration

H2CEnabled bool

Expand All @@ -89,7 +98,7 @@ func NewConfig() (*Config, error) {
config := &Config{
TargetPort: getEnvInt("TARGET_PORT", defaultTargetPort),
UpstreamCommand: os.Args[1],
UpstreamArgs: os.Args[2:],
UpstreamArgs: append([]string{}, os.Args[2:]...),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prevent UpstreamArgs mutation


CacheSizeBytes: getEnvInt("CACHE_SIZE", defaultCacheSize),
MaxCacheItemSizeBytes: getEnvInt("MAX_CACHE_ITEM_SIZE", defaultMaxCacheItemSizeBytes),
Expand All @@ -106,11 +115,16 @@ func NewConfig() (*Config, error) {
StoragePath: getEnvString("STORAGE_PATH", defaultStoragePath),
BadGatewayPage: getEnvString("BAD_GATEWAY_PAGE", defaultBadGatewayPage),

HttpPort: getEnvInt("HTTP_PORT", defaultHttpPort),
HttpsPort: getEnvInt("HTTPS_PORT", defaultHttpsPort),
HttpIdleTimeout: getEnvDuration("HTTP_IDLE_TIMEOUT", defaultHttpIdleTimeout),
HttpReadTimeout: getEnvDuration("HTTP_READ_TIMEOUT", defaultHttpReadTimeout),
HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout),
HttpPort: getEnvInt("HTTP_PORT", defaultHttpPort),
HttpsPort: getEnvInt("HTTPS_PORT", defaultHttpsPort),
HttpHealthHost: getEnvString("HTTP_HEALTH_HOST", defaultHttpHealthHost),
HttpHealthPath: getEnvString("HTTP_HEALTH_PATH", ""),
HttpHealthInterval: getEnvDuration("HTTP_HEALTH_INTERVAL", defaultHttpHealthInterval),
HttpHealthTimeout: getEnvDuration("HTTP_HEALTH_TIMEOUT", defaultHttpHealthTimeout),
HttpHealthDeadline: getEnvDuration("HTTP_HEALTH_DEADLINE", defaultHttpHealthDeadline),
HttpIdleTimeout: getEnvDuration("HTTP_IDLE_TIMEOUT", defaultHttpIdleTimeout),
HttpReadTimeout: getEnvDuration("HTTP_READ_TIMEOUT", defaultHttpReadTimeout),
HttpWriteTimeout: getEnvDuration("HTTP_WRITE_TIMEOUT", defaultHttpWriteTimeout),

H2CEnabled: getEnvBool("H2C_ENABLED", defaultH2CEnabled),

Expand Down
39 changes: 39 additions & 0 deletions internal/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ func TestConfig_defaults(t *testing.T) {
assert.Equal(t, "echo", c.UpstreamCommand)
assert.Equal(t, defaultCacheSize, c.CacheSizeBytes)
assert.Equal(t, slog.LevelInfo, c.LogLevel)
assert.Equal(t, "", c.HttpHealthPath)
assert.Equal(t, "127.0.0.1", c.HttpHealthHost)
assert.Equal(t, 1*time.Second, c.HttpHealthTimeout)
assert.Equal(t, 1*time.Second, c.HttpHealthInterval)
assert.Equal(t, 2*time.Minute, c.HttpHealthDeadline)
assert.Equal(t, false, c.H2CEnabled)
}

Expand All @@ -118,6 +123,11 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) {
usingEnvVar(t, "DEBUG", "1")
usingEnvVar(t, "ACME_DIRECTORY", "https://acme-staging-v02.api.letsencrypt.org/directory")
usingEnvVar(t, "LOG_REQUESTS", "false")
usingEnvVar(t, "HTTP_HEALTH_PATH", "/health")
usingEnvVar(t, "HTTP_HEALTH_HOST", "localhost")
usingEnvVar(t, "HTTP_HEALTH_INTERVAL", "3")
usingEnvVar(t, "HTTP_HEALTH_TIMEOUT", "4")
usingEnvVar(t, "HTTP_HEALTH_DEADLINE", "60")
usingEnvVar(t, "H2C_ENABLED", "true")
usingEnvVar(t, "GZIP_COMPRESSION_DISABLE_ON_AUTH", "true")
usingEnvVar(t, "GZIP_COMPRESSION_JITTER", "64")
Expand All @@ -132,6 +142,11 @@ func TestConfig_override_defaults_with_env_vars(t *testing.T) {
assert.Equal(t, false, c.GzipCompressionEnabled)
assert.Equal(t, slog.LevelDebug, c.LogLevel)
assert.Equal(t, "https://acme-staging-v02.api.letsencrypt.org/directory", c.ACMEDirectoryURL)
assert.Equal(t, "/health", c.HttpHealthPath)
assert.Equal(t, "localhost", c.HttpHealthHost)
assert.Equal(t, 3*time.Second, c.HttpHealthInterval)
assert.Equal(t, 4*time.Second, c.HttpHealthTimeout)
assert.Equal(t, 60*time.Second, c.HttpHealthDeadline)
assert.Equal(t, false, c.LogRequests)
assert.Equal(t, true, c.H2CEnabled)
assert.Equal(t, true, c.GzipCompressionDisableOnAuth)
Expand All @@ -146,6 +161,11 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) {
usingEnvVar(t, "THRUSTER_X_SENDFILE_ENABLED", "0")
usingEnvVar(t, "THRUSTER_DEBUG", "1")
usingEnvVar(t, "THRUSTER_LOG_REQUESTS", "0")
usingEnvVar(t, "THRUSTER_HTTP_HEALTH_PATH", "/health")
usingEnvVar(t, "THRUSTER_HTTP_HEALTH_HOST", "localhost")
usingEnvVar(t, "THRUSTER_HTTP_HEALTH_INTERVAL", "3")
usingEnvVar(t, "THRUSTER_HTTP_HEALTH_TIMEOUT", "4")
usingEnvVar(t, "THRUSTER_HTTP_HEALTH_DEADLINE", "60")
usingEnvVar(t, "THRUSTER_H2C_ENABLED", "1")

c, err := NewConfig()
Expand All @@ -157,6 +177,11 @@ func TestConfig_override_defaults_with_env_vars_using_prefix(t *testing.T) {
assert.Equal(t, false, c.XSendfileEnabled)
assert.Equal(t, slog.LevelDebug, c.LogLevel)
assert.Equal(t, false, c.LogRequests)
assert.Equal(t, "/health", c.HttpHealthPath)
assert.Equal(t, "localhost", c.HttpHealthHost)
assert.Equal(t, 3*time.Second, c.HttpHealthInterval)
assert.Equal(t, 4*time.Second, c.HttpHealthTimeout)
assert.Equal(t, 60*time.Second, c.HttpHealthDeadline)
assert.Equal(t, true, c.H2CEnabled)
}

Expand All @@ -171,6 +196,20 @@ func TestConfig_prefixed_variables_take_precedence_over_non_prefixed(t *testing.
assert.Equal(t, 4000, c.TargetPort)
}

func TestConfig_defaults_are_used_if_strconv_fails(t *testing.T) {
usingProgramArgs(t, "thruster", "echo", "hello")
usingEnvVar(t, "TARGET_PORT", "should-be-an-int")
usingEnvVar(t, "HTTP_IDLE_TIMEOUT", "should-be-a-duration")
usingEnvVar(t, "X_SENDFILE_ENABLED", "should-be-a-bool")

c, err := NewConfig()
require.NoError(t, err)

assert.Equal(t, 3000, c.TargetPort)
assert.Equal(t, 60*time.Second, c.HttpIdleTimeout)
assert.Equal(t, true, c.XSendfileEnabled)
}

func TestConfig_return_error_when_no_upstream_command(t *testing.T) {
usingProgramArgs(t, "thruster")

Expand Down
138 changes: 131 additions & 7 deletions internal/service.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
package internal

import (
"context"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
"os/signal"
"syscall"
"time"
)

type Service struct {
config *Config
}

// Represents the result of the upstream process execution.
type upstreamResult struct {
exitCode int
err error
}

func NewService(config *Config) *Service {
return &Service{
config: config,
Expand All @@ -36,23 +47,136 @@ func (s *Service) Run() int {
server := NewServer(s.config, handler)
upstream := NewUpstreamProcess(s.config.UpstreamCommand, s.config.UpstreamArgs...)

s.setEnvironment()

// Channel to receive the result from the upstream process goroutine.
resultChan := make(chan upstreamResult, 1)

// Run the upstream process in a separate goroutine
// This allows us to perform health checks while it starts up
go func() {
exitCode, err := upstream.Run()
resultChan <- upstreamResult{exitCode: exitCode, err: err}
}()

// If a health check path is configured, wait for the upstream to become healthy
if s.config.HttpHealthPath != "" {
if err := s.performHealthCheck(resultChan); err != nil {
slog.Error("Upstream health check failed", "error", err)
// At this point, the upstream process is running but unhealthy
if err := upstream.Signal(syscall.SIGTERM); err != nil {
slog.Error("Failed to send signal to upstream process", "error", err)
}
return 1
}
slog.Info("Upstream service is healthy, starting proxy server.")
}

// Now that the upstream is ready, start the main proxy server
if err := server.Start(); err != nil {
return 1
}
defer server.Stop()

s.setEnvironment()
// Delegate the waiting and signal handling to the new function
return s.awaitTermination(upstream, resultChan)
}

// Private

func (s *Service) awaitTermination(upstream *UpstreamProcess, resultChan <-chan upstreamResult) int {
signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM)

select {
case result := <-resultChan:
// The upstream process finished on its own.
slog.Info("Wrapped process exited on its own.", "exit_code", result.exitCode)
if result.err != nil {
slog.Error("Wrapped process failed", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", result.err)
return 1
}
return result.exitCode

case sig := <-signalChan:
// An OS signal was caught
slog.Info("Received signal, shutting down.", "signal", sig.String())

// Relay the signal to the child process to allow for graceful shutdown.
slog.Info("Relaying signal to upstream process...")
if err := upstream.Signal(sig); err != nil {
slog.Error("Failed to send signal to upstream process", "error", err)
}

exitCode, err := upstream.Run()
if err != nil {
slog.Error("Failed to start wrapped process", "command", s.config.UpstreamCommand, "args", s.config.UpstreamArgs, "error", err)
// Give the upstream process a moment to shut down gracefully
// before the defer server.Stop() forcefully cleans up.
select {
case <-resultChan:
slog.Info("Upstream process terminated gracefully after signal.")
case <-time.After(10 * time.Second):
slog.Warn("Upstream process did not terminate within 10 seconds of signal.")
}

// Exit with a non-zero status code to indicate termination by signal.
return 1
}

return exitCode
}

// Private
// performHealthCheck polls the health check endpoint until it gets a 200 OK
func (s *Service) performHealthCheck(resultChan <-chan upstreamResult) error {
// Create a context with a 2-minute timeout (default) for the entire health check process
ctx, cancel := context.WithTimeout(context.Background(), s.config.HttpHealthDeadline)
defer cancel()

// We assume the upstream server binds to the target URL's host
healthCheckURL := fmt.Sprintf("http://%s:%d%s", s.config.HttpHealthHost, s.config.TargetPort, s.config.HttpHealthPath)
slog.Info("Starting health checks", "url", healthCheckURL)

// Use a ticker to check every second (default)
ticker := time.NewTicker(s.config.HttpHealthInterval)
defer ticker.Stop()

// Create an HTTP client with a short timeout for individual requests
client := &http.Client{
Timeout: s.config.HttpHealthTimeout,
}

for {
select {
case <-ctx.Done():
// Deadline exceeded
return fmt.Errorf("health check timed out after %v", s.config.HttpHealthDeadline)

case result := <-resultChan:
// The upstream process exited before it became healthy
return fmt.Errorf("upstream process exited prematurely with code %d: %w", result.exitCode, result.err)

case <-ticker.C:
// Ticker fired, time to perform a check
req, err := http.NewRequestWithContext(ctx, "GET", healthCheckURL, nil)
if err != nil {
return fmt.Errorf("failed to create health check request: %w", err)
}

resp, err := client.Do(req)
if err != nil {
// This is expected while the server is starting up (e.g., "connection refused")
slog.Debug("Health check attempt failed, retrying...", "error", err)
continue
}

// Don't forget to close the body
resp.Body.Close()

if resp.StatusCode == http.StatusOK {
// Success!
return nil
}

slog.Debug("Health check received non-200 status", "status_code", resp.StatusCode)
}
}
}

func (s *Service) cache() Cache {
return NewMemoryCache(s.config.CacheSizeBytes, s.config.MaxCacheItemSizeBytes)
Expand Down