From 30c735cb1f619f2aa418e84706ca9bc3906cb901 Mon Sep 17 00:00:00 2001 From: Manu Date: Sat, 7 Mar 2026 20:55:14 -0700 Subject: [PATCH] refactor: reduce run function complexity ## Summary - Extract initTailscale, handleShutdown, and drainActiveConnections from the run function - Reduce cyclomatic complexity of run function from 19 to 7 - Improve readability and maintainability of the main entry point --- main.go | 128 ++++++++++++++++++++++++++++++-------------------------- 1 file changed, 68 insertions(+), 60 deletions(-) diff --git a/main.go b/main.go index 62912bf..ca4d2d0 100644 --- a/main.go +++ b/main.go @@ -149,14 +149,44 @@ func run(cfg config.Config) error { defer cleanupEphemeralStateDir(cfg.StateDir) } - if cfg.AutoInstance { - logger.Info("auto-instance mode enabled", - "local_addr", cfg.LocalAddr, - "hostname", cfg.Hostname, - "state_dir", cfg.StateDir, - "ephemeral_state", cfg.EphemeralState) + server, err := initTailscale(cfg) + if err != nil { + return err + } + + listener, err := net.Listen("tcp", cfg.LocalAddr) + if err != nil { + _ = server.Close() + return fmt.Errorf("bind %s: %w", cfg.LocalAddr, err) + } + + var ready atomic.Bool + var healthServer *http.Server + if cfg.HealthAddr != "" { + healthServer = health.StartServer(cfg.HealthAddr, &ready, logger) } + printBanner(cfg) + + sigCtx, sigCancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer sigCancel() + + go handleShutdown(sigCtx, &ready, listener, healthServer) + + ready.Store(true) + var activeConns sync.WaitGroup + errAccept := proxy.AcceptLoop(listener, server, cfg, &activeConns, logger) + + drainActiveConnections(cfg, &activeConns) + + if err := server.Close(); err != nil { + logger.Error("error closing tsnet server", "error", err) + } + + return errAccept +} + +func initTailscale(cfg config.Config) (*tsnet.Server, error) { var tsnetLogf func(string, ...any) if cfg.Verbose { tsnetLogf = func(format string, args ...any) { @@ -180,74 +210,52 @@ func run(cfg config.Config) error { status, err := server.Up(ctx) if err != nil { - return fmt.Errorf("tailscale init failed: %w", err) + return nil, fmt.Errorf("tailscale init failed: %w", err) } logger.Info("tailscale ready", "ip", status.Self.TailscaleIPs[0]) + return server, nil +} - listener, err := net.Listen("tcp", cfg.LocalAddr) - if err != nil { - return fmt.Errorf("bind %s: %w", cfg.LocalAddr, err) +func handleShutdown(ctx context.Context, ready *atomic.Bool, listener net.Listener, healthServer *http.Server) { + <-ctx.Done() + logger.Info("shutting down") + ready.Store(false) + if err := listener.Close(); err != nil { + logger.Error("error closing listener", "error", err) } - - // Start health server if configured - var ready atomic.Bool - var healthServer *http.Server - - if cfg.HealthAddr != "" { - healthServer = health.StartServer(cfg.HealthAddr, &ready, logger) + if healthServer != nil { + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer shutdownCancel() + if err := healthServer.Shutdown(shutdownCtx); err != nil { + logger.Error("error closing health server", "error", err) + } } +} - printBanner(cfg) +func drainActiveConnections(cfg config.Config, wg *sync.WaitGroup) { + if cfg.DrainTimeout <= 0 { + return + } - sigCtx, sigCancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - defer sigCancel() + logger.Info("draining active connections", "timeout", cfg.DrainTimeout) + drainCtx, drainCancel := context.WithTimeout(context.Background(), cfg.DrainTimeout) + defer drainCancel() + done := make(chan struct{}) go func() { - <-sigCtx.Done() - logger.Info("shutting down") - ready.Store(false) - if err := listener.Close(); err != nil { - logger.Error("error closing listener", "error", err) - } - if healthServer != nil { - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second) - defer shutdownCancel() - if err := healthServer.Shutdown(shutdownCtx); err != nil { - logger.Error("error closing health server", "error", err) - } - } + wg.Wait() + close(done) }() - ready.Store(true) - var activeConns sync.WaitGroup - errAccept := proxy.AcceptLoop(listener, server, cfg, &activeConns, logger) - - if cfg.DrainTimeout > 0 { - logger.Info("draining active connections", "timeout", cfg.DrainTimeout) - drainCtx, drainCancel := context.WithTimeout(context.Background(), cfg.DrainTimeout) - defer drainCancel() - - done := make(chan struct{}) - go func() { - activeConns.Wait() - close(done) - }() - - select { - case <-done: - logger.Info("all active connections drained gracefully") - case <-drainCtx.Done(): - logger.Warn("drain timeout exceeded, forcing shutdown") - } + select { + case <-done: + logger.Info("all active connections drained gracefully") + case <-drainCtx.Done(): + logger.Warn("drain timeout exceeded, forcing shutdown") } - - if err := server.Close(); err != nil { - logger.Error("error closing tsnet server", "error", err) - } - - return errAccept } + func printBanner(cfg config.Config) { fmt.Println() fmt.Println(" +---------------------------------------+")