diff --git a/.env.example b/.env.example index 53796bf..0c749ea 100644 --- a/.env.example +++ b/.env.example @@ -18,6 +18,7 @@ TS_TARGET=100.x.x.x:3389 # TS_LOCAL_ADDR=127.0.0.1:33389 # Local bind address # TS_CONTROL_URL= # Custom control plane (e.g. https://vpn.example.com) # TS_IDLE_TIMEOUT= # Close idle conns after this duration (e.g. 30m). Default: disabled. +# TS_DIAL_TIMEOUT=5s # Per-connection target dial timeout (distinct from TS_TIMEOUT which only covers tsnet init). # TS_DIAL_RETRIES=3 # Max retries for transient target dial failures (0 disables). # TS_DIAL_BACKOFF_BASE=1s # Base backoff for retries (multiplied by 2^attempt). # TS_DIAL_BACKOFF_MAX=30s # Cap on backoff duration per retry. diff --git a/internal/config/config.go b/internal/config/config.go index b2eaf61..d22143c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -24,6 +24,7 @@ const ( defaultDialRetries = 3 defaultDialBackoffBase = 1 * time.Second defaultDialBackoffMax = 30 * time.Second + defaultDialTimeout = 5 * time.Second ) // Config holds the bridge configuration. @@ -35,6 +36,7 @@ type Config struct { StateDir string ControlURL string ConnectTimeout time.Duration + DialTimeout time.Duration DrainTimeout time.Duration IdleTimeout time.Duration DialRetries int @@ -70,13 +72,10 @@ func LoadConfig(verboseFlag bool) (Config, error) { return Config{}, err } - idleTimeout, err := parseDurationEnv("TS_IDLE_TIMEOUT", 0) + idleTimeout, dialTimeout, err := parseTimeoutEnvs() if err != nil { return Config{}, err } - if idleTimeout < 0 { - return Config{}, fmt.Errorf("TS_IDLE_TIMEOUT must be >= 0, got %v", idleTimeout) - } dialRetries, dialBackoffBase, dialBackoffMax, err := parseDialConfig() if err != nil { @@ -96,6 +95,7 @@ func LoadConfig(verboseFlag bool) (Config, error) { StateDir: os.Getenv("TS_STATE_DIR"), ControlURL: os.Getenv("TS_CONTROL_URL"), ConnectTimeout: timeout, + DialTimeout: dialTimeout, DrainTimeout: drainTimeout, IdleTimeout: idleTimeout, DialRetries: dialRetries, @@ -165,6 +165,29 @@ func parseDialRetries() (int, error) { return n, nil } +// parseTimeoutEnvs collects the two per-connection timeouts (idle + dial) so +// LoadConfig stays under the cyclomatic-complexity threshold. Idle accepts +// 0 (disabled); dial must be strictly positive. +func parseTimeoutEnvs() (idle, dial time.Duration, err error) { + idle, err = parseDurationEnv("TS_IDLE_TIMEOUT", 0) + if err != nil { + return 0, 0, err + } + if idle < 0 { + return 0, 0, fmt.Errorf("TS_IDLE_TIMEOUT must be >= 0, got %v", idle) + } + + dial, err = parseDurationEnv("TS_DIAL_TIMEOUT", defaultDialTimeout) + if err != nil { + return 0, 0, err + } + if dial <= 0 { + return 0, 0, fmt.Errorf("TS_DIAL_TIMEOUT must be > 0, got %v", dial) + } + + return idle, dial, nil +} + // parseDialConfig collects the three ReconnectDialer parameters together so // LoadConfig stays under the cyclomatic-complexity threshold. func parseDialConfig() (retries int, base, maxBackoff time.Duration, err error) { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 81aed66..b2d7e6f 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -258,6 +258,41 @@ func TestLoadConfig(t *testing.T) { env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_IDLE_TIMEOUT": "-1m"}, wantErr: true, }, + { + name: "dial timeout defaults to 5s", + env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123"}, + wantErr: false, + check: func(t *testing.T, cfg Config) { + if cfg.DialTimeout != 5*time.Second { + t.Errorf("expected DialTimeout default 5s, got %v", cfg.DialTimeout) + } + }, + }, + { + name: "dial timeout parsed", + env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_DIAL_TIMEOUT": "10s"}, + wantErr: false, + check: func(t *testing.T, cfg Config) { + if cfg.DialTimeout != 10*time.Second { + t.Errorf("expected DialTimeout 10s, got %v", cfg.DialTimeout) + } + }, + }, + { + name: "dial timeout zero rejected", + env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_DIAL_TIMEOUT": "0"}, + wantErr: true, + }, + { + name: "dial timeout negative rejected", + env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_DIAL_TIMEOUT": "-1s"}, + wantErr: true, + }, + { + name: "dial timeout invalid rejected", + env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_DIAL_TIMEOUT": "garbage"}, + wantErr: true, + }, { name: "dial retries defaults", env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123"}, @@ -369,7 +404,7 @@ func TestLoadConfig(t *testing.T) { "TS_LOCAL_ADDR", "TS_HOSTNAME", "TS_STATE_DIR", "TS_CONTROL_URL", "TS_MAX_CONNECTIONS", "TS_HEALTH_ADDR", "TS_LOG_FORMAT", "TS_AUTO_INSTANCE", "TS_INSTANCE_NAME", "TS_PORT_RANGE", "TS_MANUAL_MODE", - "TS_DRAIN_TIMEOUT", "TS_IDLE_TIMEOUT", + "TS_DRAIN_TIMEOUT", "TS_IDLE_TIMEOUT", "TS_DIAL_TIMEOUT", "TS_DIAL_RETRIES", "TS_DIAL_BACKOFF_BASE", "TS_DIAL_BACKOFF_MAX"} { os.Unsetenv(key) } diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index a7d23f3..e9d710b 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -96,12 +96,12 @@ func AcceptLoop(listener net.Listener, dialer Dialer, cfg config.Config, wg *syn // Reset backoff on successful accept backoff = backoffMin - // Check connection limit - current := telemetry.GetActiveConnections() - if current >= cfg.MaxConnections { + // Atomically claim a connection slot. The CAS loop in TryClaimConnection + // closes the check-then-act race where a burst of accepts could each + // observe (cur < cap) and all proceed past the limit. + if !telemetry.TryClaimConnection(cfg.MaxConnections) { telemetry.AddRejectedConn() logger.Warn("connection rejected: limit reached", - "current", current, "max", cfg.MaxConnections, "client", conn.RemoteAddr()) _ = conn.Close() @@ -111,6 +111,9 @@ func AcceptLoop(listener net.Listener, dialer Dialer, cfg config.Config, wg *syn wg.Add(1) go func(c net.Conn) { defer wg.Done() + // Release the claimed slot once the per-conn work returns, + // regardless of how it terminates (success, dial fail, panic-free abort). + defer telemetry.AddActiveConnection(-1) handleConn(c, dialer, cfg, logger) }(conn) } @@ -125,10 +128,11 @@ var bufferPool = sync.Pool{ } func handleConn(client net.Conn, dialer Dialer, cfg config.Config, logger *slog.Logger) { - // Track metrics - telemetry.AddActiveConnection(1) + // Active-connection slot management lives in AcceptLoop (atomic claim + + // release via defer in the spawned goroutine). We only count the total + // here so direct callers (unit tests that bypass AcceptLoop) still see + // the per-call increment. telemetry.AddTotalConnection() - defer telemetry.AddActiveConnection(-1) addr := client.RemoteAddr().String() connStart := time.Now() @@ -144,7 +148,12 @@ func handleConn(client net.Conn, dialer Dialer, cfg config.Config, logger *slog. logger.Info("connection opened", "client", addr) - ctx, cancel := context.WithTimeout(context.Background(), cfg.ConnectTimeout) + // DialTimeout is intentionally smaller than ConnectTimeout: ConnectTimeout + // covers the one-time tsnet init (control plane handshake, can be slow), + // DialTimeout covers each per-connection target dial. With ReconnectDialer + // retrying, a large ConnectTimeout would let one stuck client hog a slot + // for many minutes; DialTimeout=5s keeps the worst case bounded. + ctx, cancel := context.WithTimeout(context.Background(), cfg.DialTimeout) defer cancel() remote, err := dialer.Dial(ctx, "tcp", cfg.Target) @@ -173,20 +182,37 @@ func handleConn(client net.Conn, dialer Dialer, cfg config.Config, logger *slog. "bytes_rx", bytesRx) } -// proxyConnections performs bidirectional copy between client and remote, -// returning the bytes transferred in each direction. -func proxyConnections(client, remote net.Conn, addr string, logger *slog.Logger) (tx, rx int64) { - var once sync.Once - closeAll := func() { - once.Do(func() { - _ = client.Close() - _ = remote.Close() - }) +// halfCloser is satisfied by net.Conn implementations that support a +// uni-directional write shutdown. *net.TCPConn and tsnet/gonet conns +// both implement this. net.Pipe does not. +type halfCloser interface { + CloseWrite() error +} + +// halfCloseWrite tries to half-close the write side of dst so the peer +// sees EOF on its read while the opposite-direction copy keeps draining. +// Unwraps idleConn to reach the underlying conn before the type assert, +// since idleConn embeds net.Conn and does not promote CloseWrite. +// Returns true if a half-close was actually performed. +func halfCloseWrite(c net.Conn) bool { + if ic, ok := c.(*idleConn); ok { + c = ic.Conn } + if cw, ok := c.(halfCloser); ok { + _ = cw.CloseWrite() + return true + } + return false +} +// proxyConnections performs bidirectional copy between client and remote, +// returning the bytes transferred in each direction. Each direction runs +// in its own goroutine; when one direction's source EOFs, the destination's +// write half is closed so the peer sees end-of-stream without tearing down +// the opposite (still-active) direction. Both ends are fully closed only +// after both directions complete. +func proxyConnections(client, remote net.Conn, addr string, logger *slog.Logger) (tx, rx int64) { copyConn := func(dst, src net.Conn, direction string, counter *int64) { - defer closeAll() - bufPtr := bufferPool.Get().(*[]byte) defer bufferPool.Put(bufPtr) @@ -195,14 +221,23 @@ func proxyConnections(client, remote net.Conn, addr string, logger *slog.Logger) switch { case err == nil, IsExpectedCloseError(err): - // Normal close — nothing to log here. + // Graceful end of this direction. Signal EOF to the peer by + // half-closing the write side; the opposite direction keeps + // draining any in-flight bytes. Falls back to full Close() + // when the underlying conn cannot half-close (e.g. net.Pipe). + if !halfCloseWrite(dst) { + _ = dst.Close() + } case isIdleTimeoutErr(err): - // Idle timeout is an operator-configured policy, not a transport - // fault. Log once at info; do not count as error. + // Idle timeout is an operator-configured policy. Log once at + // info; do not count as transport error. Full-close to release + // the slot promptly. logger.Info("connection closed (idle timeout)", "client", addr, "direction", direction, "bytes", n) + _ = dst.Close() + _ = src.Close() default: telemetry.AddError() logger.Warn("copy error", @@ -210,18 +245,28 @@ func proxyConnections(client, remote net.Conn, addr string, logger *slog.Logger) "direction", direction, "bytes", n, "error", err) + _ = dst.Close() + _ = src.Close() } } var wg sync.WaitGroup - wg.Add(1) + wg.Add(2) go func() { defer wg.Done() copyConn(client, remote, "rx", &rx) }() - copyConn(remote, client, "tx", &tx) + go func() { + defer wg.Done() + copyConn(remote, client, "tx", &tx) + }() wg.Wait() + // Idempotent full close after both directions completed — guarantees + // the conns are released even if a half-close path ran. + _ = client.Close() + _ = remote.Close() + return tx, rx } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 3ca6a1a..c0e406d 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -7,6 +7,7 @@ import ( "log/slog" "net" "sync" + "sync/atomic" "testing" "time" @@ -69,6 +70,7 @@ func TestHandleConnWithDialer(t *testing.T) { cfg := config.Config{ Target: "100.64.0.1:3389", ConnectTimeout: 5 * time.Second, + DialTimeout: 5 * time.Second, } clientConn, proxyConn := net.Pipe() @@ -136,6 +138,7 @@ func TestAcceptLoopWithDialer(t *testing.T) { cfg := config.Config{ Target: "100.64.0.1:3389", ConnectTimeout: 5 * time.Second, + DialTimeout: 5 * time.Second, MaxConnections: 1000, } @@ -341,3 +344,87 @@ func TestAcceptLoopBackoff(t *testing.T) { } } +// halfCloseConn wraps a net.Pipe end and records whether CloseWrite was +// invoked. Used to verify proxyConnections half-closes on graceful EOF. +type halfCloseConn struct { + net.Conn + closeWriteCalled atomic.Bool +} + +func (h *halfCloseConn) CloseWrite() error { + h.closeWriteCalled.Store(true) + // Mirror real TCPConn semantics: after CloseWrite the peer sees EOF + // on its read. With net.Pipe we approximate by fully closing this end + // — the opposite-direction copy will then see EOF on its own Read. + return h.Conn.Close() +} + +func TestHalfCloseWrite_UnwrapsIdleConn(t *testing.T) { + a, b := net.Pipe() + defer a.Close() + defer b.Close() + + hc := &halfCloseConn{Conn: a} + wrapped := withIdleTimeout(hc, 100*time.Millisecond) + + if !halfCloseWrite(wrapped) { + t.Fatal("halfCloseWrite should reach the inner halfCloser through idleConn") + } + if !hc.closeWriteCalled.Load() { + t.Fatal("CloseWrite was not invoked on the underlying conn") + } +} + +func TestHalfCloseWrite_ReturnsFalseForNonHalfCloser(t *testing.T) { + // net.Pipe ends do NOT implement halfCloser. Verify the fallback path. + a, b := net.Pipe() + defer a.Close() + defer b.Close() + if halfCloseWrite(a) { + t.Fatal("halfCloseWrite should return false for non-halfCloser conns") + } +} + +func TestProxyConnections_HalfClosesOnEOF(t *testing.T) { + // Server pair: what the proxy sees as "client" (a) and what tests + // drive (aPeer). Same for remote. + a, aPeer := net.Pipe() + r, rPeer := net.Pipe() + defer aPeer.Close() + defer rPeer.Close() + + clientHC := &halfCloseConn{Conn: a} + remoteHC := &halfCloseConn{Conn: r} + + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + done := make(chan struct{}) + go func() { + _, _ = proxyConnections(clientHC, remoteHC, "test", logger) + close(done) + }() + + // Send one byte from the remote side, then EOF that direction. + go func() { + _, _ = rPeer.Write([]byte("X")) + _ = rPeer.Close() + }() + + // Drain whatever arrives on the client peer; once the EOF propagates + // via half-close, the read returns and we let the other direction wind down. + buf := make([]byte, 16) + _ = aPeer.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, _ = aPeer.Read(buf) + _ = aPeer.Close() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("proxyConnections did not return after both directions ended") + } + + if !clientHC.closeWriteCalled.Load() && !remoteHC.closeWriteCalled.Load() { + t.Error("expected at least one CloseWrite invocation on graceful EOF; none recorded") + } +} + diff --git a/internal/telemetry/metrics.go b/internal/telemetry/metrics.go index b4dbd79..5f04ff4 100644 --- a/internal/telemetry/metrics.go +++ b/internal/telemetry/metrics.go @@ -36,6 +36,26 @@ func GetActiveConnections() int64 { return atomic.LoadInt64(&globalMetrics.ActiveConnections) } +// TryClaimConnection atomically increments ActiveConnections only if the +// resulting value would not exceed maxConns. Returns true on successful +// claim. Callers MUST call AddActiveConnection(-1) when the work tied to +// the claim finishes, even on error paths. +// +// Uses a CAS loop to close the check-then-act race that exists when the +// limit check and increment happen in two separate atomic ops. +func TryClaimConnection(maxConns int64) bool { + for { + cur := atomic.LoadInt64(&globalMetrics.ActiveConnections) + if cur >= maxConns { + return false + } + if atomic.CompareAndSwapInt64(&globalMetrics.ActiveConnections, cur, cur+1) { + return true + } + // CAS lost — another goroutine moved the counter, retry. + } +} + // AddTotalConnection increments the total connection count. func AddTotalConnection() { atomic.AddInt64(&globalMetrics.TotalConnections, 1) diff --git a/internal/telemetry/metrics_test.go b/internal/telemetry/metrics_test.go new file mode 100644 index 0000000..b630c8c --- /dev/null +++ b/internal/telemetry/metrics_test.go @@ -0,0 +1,70 @@ +package telemetry + +import ( + "sync" + "sync/atomic" + "testing" +) + +func TestTryClaimConnection_HonorsLimit(t *testing.T) { + ResetMetrics() + + if !TryClaimConnection(1) { + t.Fatal("first claim should succeed") + } + if TryClaimConnection(1) { + t.Fatal("second claim should fail when cap=1 and one slot held") + } + + AddActiveConnection(-1) + + if !TryClaimConnection(1) { + t.Fatal("claim should succeed again after release") + } + AddActiveConnection(-1) +} + +func TestTryClaimConnection_AtomicUnderRace(t *testing.T) { + ResetMetrics() + + const cap = 50 + const attackers = 500 + + var success atomic.Int64 + var wg sync.WaitGroup + wg.Add(attackers) + start := make(chan struct{}) + + for range attackers { + go func() { + defer wg.Done() + <-start + if TryClaimConnection(cap) { + success.Add(1) + } + }() + } + + close(start) + wg.Wait() + + got := success.Load() + if got != cap { + t.Fatalf("under race: %d claims succeeded, want exactly %d", got, cap) + } + if GetActiveConnections() != cap { + t.Errorf("ActiveConnections = %d, want %d", GetActiveConnections(), cap) + } + + // Cleanup so subsequent tests start clean. + for range cap { + AddActiveConnection(-1) + } +} + +func TestTryClaimConnection_ZeroCapAlwaysFails(t *testing.T) { + ResetMetrics() + if TryClaimConnection(0) { + t.Fatal("cap=0 should reject all claims") + } +} diff --git a/main_integration_test.go b/main_integration_test.go index ff61be5..e85dae5 100644 --- a/main_integration_test.go +++ b/main_integration_test.go @@ -308,7 +308,9 @@ func TestConcurrentConnections(t *testing.T) { } } -// TestConnectionLimit tests that connection limits are enforced. +// TestConnectionLimit tests that connection limits are enforced via the +// atomic TryClaimConnection helper. Replaced the old check-then-act +// simulation after PR introducing the CAS-based claim path. func TestConnectionLimit(t *testing.T) { telemetry.ResetMetrics() @@ -316,14 +318,12 @@ func TestConnectionLimit(t *testing.T) { MaxConnections: 2, } - // Simulate connection limit check - for i := 0; i < 5; i++ { - current := telemetry.GetActiveConnections() - if current >= cfg.MaxConnections { + // Try 5 claims; the helper should admit exactly MaxConnections of them + // and reject the rest, atomically. + for range 5 { + if !telemetry.TryClaimConnection(cfg.MaxConnections) { telemetry.AddRejectedConn() - continue } - telemetry.AddActiveConnection(1) } m := telemetry.GetMetrics() diff --git a/site/src/content/docs/getting-started.md b/site/src/content/docs/getting-started.md index 6c03ff2..3805b03 100644 --- a/site/src/content/docs/getting-started.md +++ b/site/src/content/docs/getting-started.md @@ -60,7 +60,8 @@ TS_CONTROL_URL=https://vpn.example.com | `TS_MANUAL_MODE` | `false` | Force legacy persistent mode. Takes precedence over `TS_AUTO_INSTANCE`. | | `TS_INSTANCE_NAME` | _(empty)_ | Stable instance alias for deterministic local port selection. | | `TS_PORT_RANGE` | `33389-34388` | Port range for auto mode (`START-END`). | -| `TS_TIMEOUT` | `30s` | Timeout for Tailscale initialization and dial. Go duration format. | +| `TS_TIMEOUT` | `30s` | Timeout for tsnet initialization (control-plane handshake). Go duration format. | +| `TS_DIAL_TIMEOUT` | `5s` | Per-connection target dial timeout, distinct from `TS_TIMEOUT`. Keeps stuck dials from holding a slot across retries. Go duration format. _(v1.8.0+)_ | | `TS_DRAIN_TIMEOUT` | `15s` | Timeout for graceful drain of active connections on shutdown. Go duration format. | | `TS_MAX_CONNECTIONS` | `1000` | Maximum concurrent connections before rejecting new ones. | | `TS_IDLE_TIMEOUT` | _(disabled)_ | Close connections after this period of no traffic in either direction. Go duration format (e.g. `30m`). Default `0` disables. Useful for reclaiming slots from abandoned RDP sessions. _(v1.6.0+)_ |