Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ TS_TARGET=100.x.x.x:3389
# TS_LOCAL_ADDR=127.0.0.1:33389 # Local bind address
# TS_CONTROL_URL= # Custom control plane (e.g. https://vpn.example.com)
# TS_IDLE_TIMEOUT= # Close idle conns after this duration (e.g. 30m). Default: disabled.
# TS_DIAL_TIMEOUT=5s # Per-connection target dial timeout (distinct from TS_TIMEOUT which only covers tsnet init).
# TS_DIAL_RETRIES=3 # Max retries for transient target dial failures (0 disables).
# TS_DIAL_BACKOFF_BASE=1s # Base backoff for retries (multiplied by 2^attempt).
# TS_DIAL_BACKOFF_MAX=30s # Cap on backoff duration per retry.
31 changes: 27 additions & 4 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ const (
defaultDialRetries = 3
defaultDialBackoffBase = 1 * time.Second
defaultDialBackoffMax = 30 * time.Second
defaultDialTimeout = 5 * time.Second
)

// Config holds the bridge configuration.
Expand All @@ -35,6 +36,7 @@ type Config struct {
StateDir string
ControlURL string
ConnectTimeout time.Duration
DialTimeout time.Duration
DrainTimeout time.Duration
IdleTimeout time.Duration
DialRetries int
Expand Down Expand Up @@ -70,13 +72,10 @@ func LoadConfig(verboseFlag bool) (Config, error) {
return Config{}, err
}

idleTimeout, err := parseDurationEnv("TS_IDLE_TIMEOUT", 0)
idleTimeout, dialTimeout, err := parseTimeoutEnvs()
if err != nil {
return Config{}, err
}
if idleTimeout < 0 {
return Config{}, fmt.Errorf("TS_IDLE_TIMEOUT must be >= 0, got %v", idleTimeout)
}

dialRetries, dialBackoffBase, dialBackoffMax, err := parseDialConfig()
if err != nil {
Expand All @@ -96,6 +95,7 @@ func LoadConfig(verboseFlag bool) (Config, error) {
StateDir: os.Getenv("TS_STATE_DIR"),
ControlURL: os.Getenv("TS_CONTROL_URL"),
ConnectTimeout: timeout,
DialTimeout: dialTimeout,
DrainTimeout: drainTimeout,
IdleTimeout: idleTimeout,
DialRetries: dialRetries,
Expand Down Expand Up @@ -165,6 +165,29 @@ func parseDialRetries() (int, error) {
return n, nil
}

// parseTimeoutEnvs collects the two per-connection timeouts (idle + dial) so
// LoadConfig stays under the cyclomatic-complexity threshold. Idle accepts
// 0 (disabled); dial must be strictly positive.
func parseTimeoutEnvs() (idle, dial time.Duration, err error) {
idle, err = parseDurationEnv("TS_IDLE_TIMEOUT", 0)
if err != nil {
return 0, 0, err
}
if idle < 0 {
return 0, 0, fmt.Errorf("TS_IDLE_TIMEOUT must be >= 0, got %v", idle)
}

dial, err = parseDurationEnv("TS_DIAL_TIMEOUT", defaultDialTimeout)
if err != nil {
return 0, 0, err
}
if dial <= 0 {
return 0, 0, fmt.Errorf("TS_DIAL_TIMEOUT must be > 0, got %v", dial)
}

return idle, dial, nil
}

// parseDialConfig collects the three ReconnectDialer parameters together so
// LoadConfig stays under the cyclomatic-complexity threshold.
func parseDialConfig() (retries int, base, maxBackoff time.Duration, err error) {
Expand Down
37 changes: 36 additions & 1 deletion internal/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,41 @@ func TestLoadConfig(t *testing.T) {
env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_IDLE_TIMEOUT": "-1m"},
wantErr: true,
},
{
name: "dial timeout defaults to 5s",
env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123"},
wantErr: false,
check: func(t *testing.T, cfg Config) {
if cfg.DialTimeout != 5*time.Second {
t.Errorf("expected DialTimeout default 5s, got %v", cfg.DialTimeout)
}
},
},
{
name: "dial timeout parsed",
env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_DIAL_TIMEOUT": "10s"},
wantErr: false,
check: func(t *testing.T, cfg Config) {
if cfg.DialTimeout != 10*time.Second {
t.Errorf("expected DialTimeout 10s, got %v", cfg.DialTimeout)
}
},
},
{
name: "dial timeout zero rejected",
env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_DIAL_TIMEOUT": "0"},
wantErr: true,
},
{
name: "dial timeout negative rejected",
env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_DIAL_TIMEOUT": "-1s"},
wantErr: true,
},
{
name: "dial timeout invalid rejected",
env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123", "TS_DIAL_TIMEOUT": "garbage"},
wantErr: true,
},
{
name: "dial retries defaults",
env: map[string]string{"TS_TARGET": "100.64.0.1:3389", "TS_AUTHKEY": "tskey-auth-test123"},
Expand Down Expand Up @@ -369,7 +404,7 @@ func TestLoadConfig(t *testing.T) {
"TS_LOCAL_ADDR", "TS_HOSTNAME", "TS_STATE_DIR", "TS_CONTROL_URL",
"TS_MAX_CONNECTIONS", "TS_HEALTH_ADDR", "TS_LOG_FORMAT",
"TS_AUTO_INSTANCE", "TS_INSTANCE_NAME", "TS_PORT_RANGE", "TS_MANUAL_MODE",
"TS_DRAIN_TIMEOUT", "TS_IDLE_TIMEOUT",
"TS_DRAIN_TIMEOUT", "TS_IDLE_TIMEOUT", "TS_DIAL_TIMEOUT",
"TS_DIAL_RETRIES", "TS_DIAL_BACKOFF_BASE", "TS_DIAL_BACKOFF_MAX"} {
os.Unsetenv(key)
}
Expand Down
93 changes: 69 additions & 24 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ func AcceptLoop(listener net.Listener, dialer Dialer, cfg config.Config, wg *syn
// Reset backoff on successful accept
backoff = backoffMin

// Check connection limit
current := telemetry.GetActiveConnections()
if current >= cfg.MaxConnections {
// Atomically claim a connection slot. The CAS loop in TryClaimConnection
// closes the check-then-act race where a burst of accepts could each
// observe (cur < cap) and all proceed past the limit.
if !telemetry.TryClaimConnection(cfg.MaxConnections) {
telemetry.AddRejectedConn()
logger.Warn("connection rejected: limit reached",
"current", current,
"max", cfg.MaxConnections,
"client", conn.RemoteAddr())
_ = conn.Close()
Expand All @@ -111,6 +111,9 @@ func AcceptLoop(listener net.Listener, dialer Dialer, cfg config.Config, wg *syn
wg.Add(1)
go func(c net.Conn) {
defer wg.Done()
// Release the claimed slot once the per-conn work returns,
// regardless of how it terminates (success, dial fail, panic-free abort).
defer telemetry.AddActiveConnection(-1)
handleConn(c, dialer, cfg, logger)
}(conn)
}
Expand All @@ -125,10 +128,11 @@ var bufferPool = sync.Pool{
}

func handleConn(client net.Conn, dialer Dialer, cfg config.Config, logger *slog.Logger) {
// Track metrics
telemetry.AddActiveConnection(1)
// Active-connection slot management lives in AcceptLoop (atomic claim +
// release via defer in the spawned goroutine). We only count the total
// here so direct callers (unit tests that bypass AcceptLoop) still see
// the per-call increment.
telemetry.AddTotalConnection()
defer telemetry.AddActiveConnection(-1)

addr := client.RemoteAddr().String()
connStart := time.Now()
Expand All @@ -144,7 +148,12 @@ func handleConn(client net.Conn, dialer Dialer, cfg config.Config, logger *slog.

logger.Info("connection opened", "client", addr)

ctx, cancel := context.WithTimeout(context.Background(), cfg.ConnectTimeout)
// DialTimeout is intentionally smaller than ConnectTimeout: ConnectTimeout
// covers the one-time tsnet init (control plane handshake, can be slow),
// DialTimeout covers each per-connection target dial. With ReconnectDialer
// retrying, a large ConnectTimeout would let one stuck client hog a slot
// for many minutes; DialTimeout=5s keeps the worst case bounded.
ctx, cancel := context.WithTimeout(context.Background(), cfg.DialTimeout)
defer cancel()

remote, err := dialer.Dial(ctx, "tcp", cfg.Target)
Expand Down Expand Up @@ -173,20 +182,37 @@ func handleConn(client net.Conn, dialer Dialer, cfg config.Config, logger *slog.
"bytes_rx", bytesRx)
}

// proxyConnections performs bidirectional copy between client and remote,
// returning the bytes transferred in each direction.
func proxyConnections(client, remote net.Conn, addr string, logger *slog.Logger) (tx, rx int64) {
var once sync.Once
closeAll := func() {
once.Do(func() {
_ = client.Close()
_ = remote.Close()
})
// halfCloser is satisfied by net.Conn implementations that support a
// uni-directional write shutdown. *net.TCPConn and tsnet/gonet conns
// both implement this. net.Pipe does not.
type halfCloser interface {
CloseWrite() error
}

// halfCloseWrite tries to half-close the write side of dst so the peer
// sees EOF on its read while the opposite-direction copy keeps draining.
// Unwraps idleConn to reach the underlying conn before the type assert,
// since idleConn embeds net.Conn and does not promote CloseWrite.
// Returns true if a half-close was actually performed.
func halfCloseWrite(c net.Conn) bool {
if ic, ok := c.(*idleConn); ok {
c = ic.Conn
}
if cw, ok := c.(halfCloser); ok {
_ = cw.CloseWrite()
return true
}
return false
}

// proxyConnections performs bidirectional copy between client and remote,
// returning the bytes transferred in each direction. Each direction runs
// in its own goroutine; when one direction's source EOFs, the destination's
// write half is closed so the peer sees end-of-stream without tearing down
// the opposite (still-active) direction. Both ends are fully closed only
// after both directions complete.
func proxyConnections(client, remote net.Conn, addr string, logger *slog.Logger) (tx, rx int64) {
copyConn := func(dst, src net.Conn, direction string, counter *int64) {
defer closeAll()

bufPtr := bufferPool.Get().(*[]byte)
defer bufferPool.Put(bufPtr)

Expand All @@ -195,33 +221,52 @@ func proxyConnections(client, remote net.Conn, addr string, logger *slog.Logger)

switch {
case err == nil, IsExpectedCloseError(err):
// Normal close — nothing to log here.
// Graceful end of this direction. Signal EOF to the peer by
// half-closing the write side; the opposite direction keeps
// draining any in-flight bytes. Falls back to full Close()
// when the underlying conn cannot half-close (e.g. net.Pipe).
if !halfCloseWrite(dst) {
_ = dst.Close()
}
case isIdleTimeoutErr(err):
// Idle timeout is an operator-configured policy, not a transport
// fault. Log once at info; do not count as error.
// Idle timeout is an operator-configured policy. Log once at
// info; do not count as transport error. Full-close to release
// the slot promptly.
logger.Info("connection closed (idle timeout)",
"client", addr,
"direction", direction,
"bytes", n)
_ = dst.Close()
_ = src.Close()
default:
telemetry.AddError()
logger.Warn("copy error",
"client", addr,
"direction", direction,
"bytes", n,
"error", err)
_ = dst.Close()
_ = src.Close()
}
}

var wg sync.WaitGroup
wg.Add(1)
wg.Add(2)
go func() {
defer wg.Done()
copyConn(client, remote, "rx", &rx)
}()
copyConn(remote, client, "tx", &tx)
go func() {
defer wg.Done()
copyConn(remote, client, "tx", &tx)
}()
wg.Wait()

// Idempotent full close after both directions completed — guarantees
// the conns are released even if a half-close path ran.
_ = client.Close()
_ = remote.Close()

return tx, rx
}

Expand Down
Loading
Loading