Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 68 additions & 60 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,44 @@ func run(cfg config.Config) error {
defer cleanupEphemeralStateDir(cfg.StateDir)
}

if cfg.AutoInstance {
logger.Info("auto-instance mode enabled",
"local_addr", cfg.LocalAddr,
"hostname", cfg.Hostname,
"state_dir", cfg.StateDir,
"ephemeral_state", cfg.EphemeralState)
server, err := initTailscale(cfg)
if err != nil {
return err
}

listener, err := net.Listen("tcp", cfg.LocalAddr)
if err != nil {
_ = server.Close()
return fmt.Errorf("bind %s: %w", cfg.LocalAddr, err)
}

var ready atomic.Bool
var healthServer *http.Server
if cfg.HealthAddr != "" {
healthServer = health.StartServer(cfg.HealthAddr, &ready, logger)
}

printBanner(cfg)

sigCtx, sigCancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer sigCancel()

go handleShutdown(sigCtx, &ready, listener, healthServer)

ready.Store(true)
var activeConns sync.WaitGroup
errAccept := proxy.AcceptLoop(listener, server, cfg, &activeConns, logger)

drainActiveConnections(cfg, &activeConns)

if err := server.Close(); err != nil {
logger.Error("error closing tsnet server", "error", err)
}

return errAccept
}

func initTailscale(cfg config.Config) (*tsnet.Server, error) {
var tsnetLogf func(string, ...any)
if cfg.Verbose {
tsnetLogf = func(format string, args ...any) {
Expand All @@ -180,74 +210,52 @@ func run(cfg config.Config) error {

status, err := server.Up(ctx)
if err != nil {
return fmt.Errorf("tailscale init failed: %w", err)
return nil, fmt.Errorf("tailscale init failed: %w", err)
}
logger.Info("tailscale ready", "ip", status.Self.TailscaleIPs[0])
return server, nil
}

listener, err := net.Listen("tcp", cfg.LocalAddr)
if err != nil {
return fmt.Errorf("bind %s: %w", cfg.LocalAddr, err)
func handleShutdown(ctx context.Context, ready *atomic.Bool, listener net.Listener, healthServer *http.Server) {
<-ctx.Done()
logger.Info("shutting down")
ready.Store(false)
if err := listener.Close(); err != nil {
logger.Error("error closing listener", "error", err)
}

// Start health server if configured
var ready atomic.Bool
var healthServer *http.Server

if cfg.HealthAddr != "" {
healthServer = health.StartServer(cfg.HealthAddr, &ready, logger)
if healthServer != nil {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := healthServer.Shutdown(shutdownCtx); err != nil {
logger.Error("error closing health server", "error", err)
}
}
}

printBanner(cfg)
func drainActiveConnections(cfg config.Config, wg *sync.WaitGroup) {
if cfg.DrainTimeout <= 0 {
return
}

sigCtx, sigCancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer sigCancel()
logger.Info("draining active connections", "timeout", cfg.DrainTimeout)
drainCtx, drainCancel := context.WithTimeout(context.Background(), cfg.DrainTimeout)
defer drainCancel()

done := make(chan struct{})
go func() {
<-sigCtx.Done()
logger.Info("shutting down")
ready.Store(false)
if err := listener.Close(); err != nil {
logger.Error("error closing listener", "error", err)
}
if healthServer != nil {
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
defer shutdownCancel()
if err := healthServer.Shutdown(shutdownCtx); err != nil {
logger.Error("error closing health server", "error", err)
}
}
wg.Wait()
close(done)
}()

ready.Store(true)
var activeConns sync.WaitGroup
errAccept := proxy.AcceptLoop(listener, server, cfg, &activeConns, logger)

if cfg.DrainTimeout > 0 {
logger.Info("draining active connections", "timeout", cfg.DrainTimeout)
drainCtx, drainCancel := context.WithTimeout(context.Background(), cfg.DrainTimeout)
defer drainCancel()

done := make(chan struct{})
go func() {
activeConns.Wait()
close(done)
}()

select {
case <-done:
logger.Info("all active connections drained gracefully")
case <-drainCtx.Done():
logger.Warn("drain timeout exceeded, forcing shutdown")
}
select {
case <-done:
logger.Info("all active connections drained gracefully")
case <-drainCtx.Done():
logger.Warn("drain timeout exceeded, forcing shutdown")
}

if err := server.Close(); err != nil {
logger.Error("error closing tsnet server", "error", err)
}

return errAccept
}


func printBanner(cfg config.Config) {
fmt.Println()
fmt.Println(" +---------------------------------------+")
Expand Down
Loading