diff --git a/.gitignore b/.gitignore index 98a0b8d..af43e7c 100644 --- a/.gitignore +++ b/.gitignore @@ -50,3 +50,5 @@ AGENTS.md # Testing coverage.out coverage.html +site/.astro/ +site/node_modules/ diff --git a/README.md b/README.md index eb426d3..2fa7505 100644 --- a/README.md +++ b/README.md @@ -1,346 +1,114 @@ -# ts-bridge - -TCP bridge for tunneling connections through Tailscale's encrypted mesh network without requiring administrator privileges on the client machine. - -## Overview - -Connect via RDP/SSH from a **non-admin machine** to an **admin machine** through restrictive firewalls using Tailscale's userspace networking. +[![CI](https://github.com/mlorentedev/ts-bridge/actions/workflows/ci.yml/badge.svg)](https://github.com/mlorentedev/ts-bridge/actions/workflows/ci.yml) +[![Go Version](https://img.shields.io/github/go-mod/go-version/mlorentedev/ts-bridge)](https://go.dev/) +[![Docs](https://img.shields.io/badge/docs-live-brightgreen)](https://mlorentedev.github.io/ts-bridge/) +[![License](https://img.shields.io/badge/license-MIT-green)](LICENSE) -## Current Status (2026-02-24) - -- **Windows workflow validated**: auto mode default, alias-based launch, and bootstrap setup script. -- **CI security checks passing target state**: gosec G115 overflow warning fixed in port-selection arithmetic. -- **Runtime hardening in progress**: Windows cleanup/log-noise improvements implemented; field validation continues. -- **Linux parity validation pending**: waiting for Linux client access to complete end-to-end checks. +# ts-bridge -| Machine | Admin Rights | Tailscale | Role | -|---------|--------------|-----------|------| -| **Client** | No | Not installed (uses tsnet) | Initiates connection | -| **Host** | Yes | Installed natively | Receives connection | +On-demand Tailscale TCP bridge for non-admin machines. Connect to remote resources securely from locked-down environments. -## Why ts-bridge? +## The Problem -| Requirement | Native Tailscale | ts-bridge | -|-------------|------------------|-----------| -| Admin rights on client | **Yes** | **No** | -| Kernel module | Yes | No (userspace) | -| Software installation | Required | Portable binary | -| Leaves traces | Yes | No (ephemeral) | -| Works on locked-down machines | No | **Yes** | +Working with secure networks often requires VPNs like Tailscale. However, native Tailscale clients require administrator privileges to install and create persistent network interfaces. In many enterprise, corporate, or locked-down environments, users do not have admin rights on their client machines, completely blocking access to critical remote resources via Tailscale. -## Control Plane Support +## The Solution -ts-bridge works with both **Tailscale SaaS** (default) and **self-hosted [Headscale](https://github.com/juanfont/headscale)**. +ts-bridge runs a full, standalone Tailscale node purely in userspace using `tsnet`. It acts as a local proxy, forwarding TCP traffic (like RDP, SSH, or HTTP) through the encrypted mesh network. -| | Tailscale SaaS | Headscale | +| | Native Tailscale | ts-bridge | |---|---|---| -| **Setup** | Default, no extra config | Set `TS_CONTROL_URL` | -| **Auth key prefix** | `tskey-auth-*` | `hskey-auth-*` | -| **Ephemeral cleanup** | Automatic | Requires `--ephemeral` flag on the pre-auth key | -| **Minimum version** | Any | Headscale v0.28.0+ requires tsnet >= v1.74 (ts-bridge v1.3.0+) | - -### Headscale Quick Setup - -```bash -# .env -TS_CONTROL_URL=https://vpn.example.com -TS_AUTHKEY=hskey-auth-xxxxx -TS_TARGET=100.64.0.5:3389 -``` - -Generate the auth key on the Headscale server: -```bash -headscale preauthkeys create --user --reusable --ephemeral --expiration 8760h -``` - -> **Important:** The `--ephemeral` flag must be on the **pre-auth key**, not just in ts-bridge config. -> Without it, nodes persist as offline entries after disconnect. - -## Quick Start - -### 1. Host Machine (Admin Rights Required) - -Install [Tailscale](https://tailscale.com/download) normally, then run the setup script: - -```powershell -# Run as Administrator -cd scripts\host -PowerShell -ExecutionPolicy Bypass -File .\setup.ps1 -``` - -Note the Tailscale IP shown (e.g., `100.82.151.104`). For Headscale, use `tailscale up --login-server=https://vpn.example.com`. See [Host Setup Guide](#host-setup-guide) for manual steps and troubleshooting. - -### 2. Client Machine (No Admin Rights) - -Download from [Releases](https://github.com/mlorentedev/ts-bridge/releases), extract, and configure: - -```bash -tar -xzf ts-bridge-linux-amd64.tar.gz -cd ts-bridge-linux-amd64 -cp .env.example .env -``` - -Edit `.env` — only two variables are required: - -```bash -TS_AUTHKEY=tskey-auth-kXXXXXXXXX # From Tailscale admin or Headscale (hskey-auth-*) -TS_TARGET=100.82.151.104:3389 # Host's Tailscale/Headscale IP + RDP port -``` - -### 3. Run - -```bash -# Linux/macOS -./scripts/client/run.sh - -# Windows -PowerShell -ExecutionPolicy Bypass -File .\scripts\client\run.ps1 -``` - -### 4. Connect - -```bash -# Use the local port shown in ts-bridge banner (auto mode picks it) -# Linux -xfreerdp /v:127.0.0.1: /u:Username /cert:ignore - -# Windows -mstsc /v:127.0.0.1: - -# macOS (Microsoft Remote Desktop) -# Add PC → 127.0.0.1: -``` - -RDP concurrency is enforced by the target host OS/policy (for example, many Windows desktop editions allow only one interactive session at a time). - -## Host Setup Guide - -The host machine (the one you connect **to**) needs specific configuration to accept RDP connections over Tailscale. - -### Requirements - -| Requirement | Details | -|-------------|---------| -| **Windows Edition** | Pro, Enterprise, Education, or Server. **Home edition cannot host RDP.** | -| **Tailscale** | Installed and connected to the tailnet | -| **Admin rights** | Needed for initial setup only | - -### Step 1: Enable Remote Desktop - -Settings > System > Remote Desktop > Toggle **On**. +| **Admin rights on client** | Required | **None needed** | +| **Kernel footprint** | persistent TUN/TAP | **Zero** (userspace) | +| **Installation** | System package | **Portable binary** | +| **Node persistence** | Remains on tailnet | **Ephemeral** (auto-deletes) | -Or via PowerShell (Admin): -```powershell -Set-ItemProperty -Path 'HKLM:\System\CurrentControlSet\Control\Terminal Server' ` - -Name "fDenyTSConnections" -Value 0 -``` - -### Step 2: Configure Authentication - -**Network Level Authentication (NLA)** is enabled by default — keep it on. - -The account you connect with **must have a traditional password set**. The following do NOT work for RDP: - -| Auth Method | Works? | Fix | -|-------------|--------|-----| -| Local account + password | Yes | — | -| Microsoft account + password | Yes | Username: `MicrosoftAccount\user@outlook.com` | -| Microsoft account (passwordless) | **No** | Set a password at account.microsoft.com > Security | -| Windows Hello PIN | **No** | Use password instead | -| Blank/empty password | **No** | Set a password on the account | - -### Step 3: Configure Firewall - -The automated `setup.ps1` handles this. For manual setup: - -```powershell -# Enable built-in RDP rules -Enable-NetFirewallRule -DisplayGroup "Remote Desktop" - -# Restrict RDP to Tailscale subnet only (recommended) -New-NetFirewallRule -DisplayName "Allow RDP over Tailscale" ` - -Direction Inbound -Protocol TCP -LocalPort 3389 ` - -RemoteAddress 100.64.0.0/10 -Action Allow -Profile Private -``` - -### Step 4: Tailscale Configuration +## Quick Install -```powershell -# Enable unattended mode (stays connected without user logged in) -tailscale up --unattended - -# Verify Tailscale IP -tailscale ip -4 -``` - -**Recommended:** Disable key expiry for the host machine in the [Tailscale admin console](https://login.tailscale.com/admin/machines) so it doesn't silently drop off the tailnet. - -### Common Host Issues - -| Symptom | Cause | Fix | -|---------|-------|-----| -| "Your Home edition doesn't support Remote Desktop" | Windows Home | Upgrade to Pro or use a different host | -| RDP connection refused | Firewall blocking Tailscale subnet | Create explicit rule for `100.64.0.0/10` on TCP 3389 | -| "CredSSP encryption oracle" error | Mismatched Windows Update levels | Patch both client and host to latest | -| Connection drops after host reboot | Tailscale not running as service | Verify Tailscale service: `Get-Service Tailscale` | -| "The credentials did not work" | Passwordless Microsoft account | Set a traditional password | -| Works locally but not over Tailscale | Tailscale ACLs restricting access | Check [ACL rules](https://login.tailscale.com/admin/acls) allow TCP 3389 | -| Third-party antivirus blocking | AV conflicts with Tailscale WFP rules | Add Tailscale exception in AV settings | - -## Configuration Reference +### 1. Client Machine (No Admin) -Create `.env` from `.env.example`. Only `TS_AUTHKEY` and `TS_TARGET` are required — everything else has sensible defaults. - -### Required - -| Variable | Description | Example | -|----------|-------------|---------| -| `TS_AUTHKEY` | Auth key. Tailscale SaaS: [generate here](https://login.tailscale.com/admin/settings/keys). Headscale: `headscale preauthkeys create`. Prefix: `tskey-` or `hskey-`. | `tskey-auth-kXXXXXX` | -| `TS_TARGET` | Host address on the mesh network. Supports IP or MagicDNS hostname. | `100.82.151.104:3389` or `my-desktop:3389` | - -### Optional - -| Variable | Default | Description | Example | -|----------|---------|-------------|---------| -| `TS_LOCAL_ADDR` | `127.0.0.1:33389` | Local address to bind the bridge listener. Auto-derived in auto mode when unset. | `127.0.0.1:43389` | -| `TS_CONTROL_URL` | _(Tailscale default)_ | Custom control plane URL. Set this to use a self-hosted [Headscale](https://github.com/juanfont/headscale) server. | `https://vpn.example.com` | -| `TS_HOSTNAME` | `ts-bridge` | Node name in the admin console. Auto-generated per run in auto mode when unset. | `bridge-workpc` | -| `TS_STATE_DIR` | `./ts-state` | Directory for node state. Auto-created with `0700` permissions. Ephemeral temp dir in auto mode when unset. | `/tmp/ts-bridge-state` | -| `TS_AUTO_INSTANCE` | `true` | Auto mode toggle (`false` disables auto behavior). | `false` | -| `TS_MANUAL_MODE` | `false` | Force legacy persistent/manual behavior (`true` takes precedence over `TS_AUTO_INSTANCE`). | `true` | -| `TS_INSTANCE_NAME` | _(empty)_ | Stable instance alias used for deterministic local port selection. | `office-laptop` | -| `TS_PORT_RANGE` | `33389-34388` | Port range used by auto mode (`START-END`). | `61000-61100` | -| `TS_TIMEOUT` | `30s` | Timeout for Tailscale initialization and dial. Go duration format. | `1m`, `45s` | -| `TS_MAX_CONNECTIONS` | `1000` | Maximum concurrent connections before rejecting new ones. | `50` | -| `TS_HEALTH_ADDR` | _(disabled)_ | Address for health/metrics HTTP server. | `127.0.0.1:8080` | -| `TS_VERBOSE` | `false` | Enable debug logging. Also available as `-v` flag. | `true` | -| `TS_LOG_FORMAT` | `text` | Log output format. | `text` or `json` | - -### Minimal `.env` (low friction) +Download the binary from [Releases](https://github.com/mlorentedev/ts-bridge/releases) and create a `.env` file: ```env -TS_AUTHKEY=tskey-auth-... -TS_TARGET=100.x.x.x:3389 +TS_AUTHKEY=tskey-auth-kXXXXXXXXX # From Tailscale admin panel +TS_TARGET=100.82.151.104:3389 # Host's Tailscale IP + RDP port TS_INSTANCE_NAME=office-laptop ``` -### Bootstrap per OS (recommended) +### 2. Host Setup (Admin) -```bash -# Linux/macOS -./scripts/client/bootstrap.sh --authkey tskey-auth-... --target 100.x.x.x:3389 --instance office-laptop -``` +Ensure Tailscale is running on the target machine and RDP is enabled. The repository includes an automated PowerShell script: ```powershell -# Windows -PowerShell -ExecutionPolicy Bypass -File .\scripts\client\bootstrap.ps1 -AuthKey tskey-auth-... -Target 100.x.x.x:3389 -Instance office-laptop +# Run as Administrator +cd scripts\host +PowerShell -ExecutionPolicy Bypass -File .\setup.ps1 ``` -### Auto Mode (default) +## What You Get -Auto mode is enabled by default and is recommended for multi-device usage with minimal setup friction. +| Feature | Description | +|---|---| +| **Zero-Admin VPN** | Connect from heavily restricted laptops without filing an IT ticket. | +| **Headscale Support** | Compatible with open-source control planes (via `TS_CONTROL_URL`). | +| **Multi-Instance** | Run multiple bridges concurrently to connect to different machines. | +| **Ephemeral by Default** | Leaves no trace. The node is automatically removed from the network when the bridge closes. | -```bash -# .env -TS_INSTANCE_NAME=office-laptop -TS_PORT_RANGE=33389-34388 -``` - -To force legacy persistent/manual behavior: +## Before/After (The Workflow) +### Before (Native Tailscale on locked-down PC) ```bash -# .env -TS_MANUAL_MODE=true +> tailscale up +Error: Administrator privilege is required to install or start the Tailscale service. ``` -Optional alias override when launching: - +### After (ts-bridge) ```bash -# Linux/macOS -./scripts/client/run.sh --instance office-laptop - -# Windows -PowerShell -ExecutionPolicy Bypass -File .\scripts\client\run.ps1 -Instance office-laptop +> ./ts-bridge + +---------------------------------------+ + | TAILSCALE BRIDGE v1.3.0 | + +---------------------------------------+ + | Host: tsb-office-laptop-a1b2c3 | + | Local: 127.0.0.1:33389 | + | Target: 100.82.151.104:3389 | + +---------------------------------------+ + Waiting for connections... ``` +Now, connect locally: `mstsc /v:127.0.0.1:33389` -In auto mode (with related vars unset), ts-bridge derives a deterministic local port, generates a unique hostname per run, and uses an ephemeral state directory. -On Windows shutdown, ephemeral cleanup is retried briefly to reduce transient "directory is not empty" races. +## Configuration -### Health Endpoint - -When `TS_HEALTH_ADDR` is set: - -```bash -curl http://127.0.0.1:8080/health/live # {"status":"ok"} — process alive -curl http://127.0.0.1:8080/health/ready # {"status":"ok"} — tsnet tunnel up -curl http://127.0.0.1:8080/metrics # Connection stats (JSON) -``` - -### Command Line +| Variable | Default | Description | +|---|---|---| +| `TS_AUTHKEY` | — | **Required**. Tailscale/Headscale auth key. | +| `TS_TARGET` | — | **Required**. Target IP/hostname and port (e.g., `100.x.x.x:3389`). | +| `TS_INSTANCE_NAME` | — | Optional alias to derive a stable local port. | +| `TS_LOCAL_ADDR` | Auto | Force a specific local address (e.g., `127.0.0.1:4000`). | +| `TS_CONTROL_URL` | — | Set to your Headscale server URL if not using Tailscale SaaS. | -```bash -./ts-bridge -version # Show version -./ts-bridge -v # Verbose logging (same as TS_VERBOSE=true) -``` +For advanced configuration (timeouts, limits, legacy modes), see the [Full Documentation](https://mlorentedev.github.io/ts-bridge/). -## How It Works +## Architecture ```text -┌─────────────────────────────────────────────────────────────────────┐ -│ CLIENT (Non-Admin) │ -│ │ -│ RDP Client ──▶ ts-bridge ──▶ tsnet (userspace WireGuard) │ -│ 127.0.0.1: No admin required │ -│ │ -│ ┌─────────────────────────────────────────────────┐ │ -│ │ Firewall: UDP blocked, HTTPS allowed │◀── Tunnels │ -│ └─────────────────────────────────────────────────┘ via DERP │ -└─────────────────────────────────────────────────────────────────────┘ - │ - │ Tailscale Network (WireGuard encrypted) - ▼ -┌─────────────────────────────────────────────────────────────────────┐ -│ HOST (Admin) │ -│ │ -│ Tailscale (Native) ──▶ RDP Server :3389 │ -│ 100.x.x.x │ -└─────────────────────────────────────────────────────────────────────┘ +┌─────────────────────────┐ +│ CLIENT (Non-Admin) │ +│ RDP/SSH → :33389 │ +│ ↓ │ +│ ts-bridge (userspace) │ +└────┬────────────────────┘ + │ encrypted via WireGuard (DERP/STUN) +┌────▼────────────────────┐ +│ HOST (Admin) │ +│ Tailscale (native) │ +│ ↓ │ +│ RDP/SSH Server │ +└─────────────────────────┘ ``` -1. ts-bridge creates ephemeral node via [tsnet](https://pkg.go.dev/tailscale.com/tsnet) on Tailscale SaaS or Headscale -2. WireGuard runs in userspace (no kernel module, no admin) -3. If UDP blocked, uses DERP relay over HTTPS -4. All traffic end-to-end encrypted (WireGuard + RDP TLS) -5. Node auto-deletes on exit (Headscale: requires `--ephemeral` pre-auth key) - -## Limitations - -| Limitation | Impact | Mitigation | -|------------|--------|------------| -| TCP only | No UDP (VoIP, games) | Use for RDP, SSH, HTTP | -| DERP latency | +50-200ms when relayed | Acceptable for RDP | -| Auth key expiry | Default 90 days (Tailscale), configurable (Headscale) | Use long-lived keys or `--expiration 8760h` on Headscale | -| Single target | One host per instance | Run multiple instances with different `TS_TARGET` | -| Windows Home | Cannot host RDP | Use Windows Pro/Enterprise on host | -| RDP host policy | Concurrent sessions may be limited by OS edition | Use multiple hosts or RDS-enabled server setup | - -## Security - -- **No admin footprint**: Runs entirely in userspace -- **Ephemeral nodes**: Auto-delete on exit (Headscale: requires `--ephemeral` pre-auth key) -- **E2E encryption**: WireGuard encryption even through DERP relay -- **Local only**: Binds to `127.0.0.1` by default -- **Secure state**: Directory created with `0700` permissions -- **CI security checks**: Port-selection arithmetic hardened to satisfy gosec overflow checks - -## Documentation - -- [Contributing](CONTRIBUTING.md) - Development setup, testing, releases - -## Support +## Contributing -For questions, bugs, or feature requests, please [open an issue](https://github.com/mlorentedev/ts-bridge/issues). +See [CONTRIBUTING.md](CONTRIBUTING.md) for development setup, testing, and PR guidelines. ## License -[MIT](LICENSE) +MIT — see [LICENSE](LICENSE). diff --git a/go.mod b/go.mod index ab4dc0c..267463c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module ts-bridge -go 1.25 +go 1.24 -toolchain go1.25.0 +toolchain go1.24.0 require tailscale.com v1.80.0 diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..d2a84fc --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,342 @@ +package config + +import ( + "errors" + "fmt" + "hash/fnv" + "net" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +const ( + // Default runtime values. + defaultLocalAddr = "127.0.0.1:33389" + defaultHostname = "ts-bridge" + defaultStateDir = "./ts-state" + defaultAutoPortRange = "33389-34388" + defaultTimeout = 30 * time.Second + defaultDrainTimeout = 15 * time.Second + defaultMaxConnections = 1000 +) + +// Config holds the bridge configuration. +type Config struct { + LocalAddr string + Target string + AuthKey string // #nosec G117 -- internal struct, never serialized + Hostname string + StateDir string + ControlURL string + ConnectTimeout time.Duration + DrainTimeout time.Duration + MaxConnections int64 + HealthAddr string + Verbose bool + LogFormat string + AutoInstance bool + EphemeralState bool +} + +// LoadConfig parses environment variables into a Config struct. +func LoadConfig(verboseFlag bool) (Config, error) { + target, err := parseTarget() + if err != nil { + return Config{}, err + } + + authKey, err := parseAuthKey() + if err != nil { + return Config{}, err + } + + timeout, err := parseDurationEnv("TS_TIMEOUT", defaultTimeout) + if err != nil { + return Config{}, err + } + + drainTimeout, err := parseDurationEnv("TS_DRAIN_TIMEOUT", defaultDrainTimeout) + if err != nil { + return Config{}, err + } + + maxConns, err := parseInt64Env("TS_MAX_CONNECTIONS", defaultMaxConnections) + if err != nil { + return Config{}, err + } + + cfg := Config{ + LocalAddr: os.Getenv("TS_LOCAL_ADDR"), + Target: target, + AuthKey: authKey, + Hostname: os.Getenv("TS_HOSTNAME"), + StateDir: os.Getenv("TS_STATE_DIR"), + ControlURL: os.Getenv("TS_CONTROL_URL"), + ConnectTimeout: timeout, + DrainTimeout: drainTimeout, + MaxConnections: maxConns, + HealthAddr: os.Getenv("TS_HEALTH_ADDR"), + Verbose: verboseFlag || parseBoolEnv(os.Getenv("TS_VERBOSE")), + LogFormat: EnvOr("TS_LOG_FORMAT", "text"), + } + + if err := applyAutoInstanceConfig(&cfg); err != nil { + return Config{}, err + } + + if cfg.LocalAddr == "" { + cfg.LocalAddr = defaultLocalAddr + } + if cfg.Hostname == "" { + cfg.Hostname = defaultHostname + } + if cfg.StateDir == "" { + cfg.StateDir = defaultStateDir + } + + return cfg, nil +} + +func parseDurationEnv(key string, fallback time.Duration) (time.Duration, error) { + v := os.Getenv(key) + if v == "" { + return fallback, nil + } + d, err := time.ParseDuration(v) + if err != nil { + return 0, fmt.Errorf("%s invalid: %w", key, err) + } + return d, nil +} + +func parseInt64Env(key string, fallback int64) (int64, error) { + v := os.Getenv(key) + if v == "" { + return fallback, nil + } + n, err := strconv.ParseInt(v, 10, 64) + if err != nil || n < 1 { + return 0, fmt.Errorf("%s invalid: %w", key, err) + } + return n, nil +} + + + +func parseTarget() (string, error) { + target := os.Getenv("TS_TARGET") + if target == "" { + return "", errors.New("TS_TARGET is required (format: HOST:PORT)") + } + + host, portStr, err := net.SplitHostPort(target) + if err != nil { + return "", fmt.Errorf("TS_TARGET invalid format: %w", err) + } + if host == "" { + return "", errors.New("TS_TARGET: host cannot be empty") + } + port, err := strconv.Atoi(portStr) + if err != nil || port < 1 || port > 65535 { + return "", fmt.Errorf("TS_TARGET: invalid port %q", portStr) + } + return target, nil +} + +func parseAuthKey() (string, error) { + authKey := os.Getenv("TS_AUTHKEY") + if authKey == "" { + return "", errors.New("TS_AUTHKEY is required") + } + if !strings.HasPrefix(authKey, "tskey-") && !strings.HasPrefix(authKey, "hskey-") { + return "", errors.New("TS_AUTHKEY: invalid format (must start with tskey- or hskey-)") + } + return authKey, nil +} + +// EnvOr returns the environment variable or a fallback. +func EnvOr(key, fallback string) string { + if v := os.Getenv(key); v != "" { + return v + } + return fallback +} + +func applyAutoInstanceConfig(cfg *Config) error { + cfg.AutoInstance = shouldEnableAutoInstance() + if !cfg.AutoInstance { + return nil + } + + instanceName := os.Getenv("TS_INSTANCE_NAME") + portRange := EnvOr("TS_PORT_RANGE", defaultAutoPortRange) + + if cfg.LocalAddr == "" { + localAddr, err := deriveAutoLocalAddr(cfg.Target, instanceName, portRange) + if err != nil { + return err + } + cfg.LocalAddr = localAddr + } + + if cfg.Hostname == "" { + cfg.Hostname = deriveAutoHostname(cfg.Target, instanceName) + } + + if cfg.StateDir == "" { + cfg.StateDir = filepath.Join(os.TempDir(), "ts-bridge", cfg.Hostname) + cfg.EphemeralState = true + } + + return nil +} + +func shouldEnableAutoInstance() bool { + if parseBoolEnv(os.Getenv("TS_MANUAL_MODE")) { + return false + } + + rawAutoMode := strings.TrimSpace(os.Getenv("TS_AUTO_INSTANCE")) + if rawAutoMode == "" { + return true + } + + return parseBoolEnv(rawAutoMode) +} + +func parseBoolEnv(value string) bool { + switch strings.ToLower(strings.TrimSpace(value)) { + case "1", "true", "yes", "on": + return true + default: + return false + } +} + +func deriveAutoLocalAddr(target, instanceName, portRange string) (string, error) { + start, end, err := parsePortRange(portRange) + if err != nil { + return "", err + } + + hostName, err := os.Hostname() + if err != nil || hostName == "" { + hostName = "unknown-host" + } + + seed := fmt.Sprintf("%s|%s|%s", hostName, target, instanceName) + port, err := selectAvailablePort(seed, start, end) + if err != nil { + return "", err + } + + return fmt.Sprintf("127.0.0.1:%d", port), nil +} + +func deriveAutoHostname(target, instanceName string) string { + hostName, err := os.Hostname() + if err != nil || hostName == "" { + hostName = "unknown-host" + } + + machine := sanitizeHostnameLabel(hostName) + instance := sanitizeHostnameLabel(instanceName) + if instance == "" { + instance = machine + } + if instance == "" { + instance = "bridge" + } + + base := "tsb-" + instance + if len(base) > 30 { + base = strings.Trim(base[:30], "-") + } + if base == "" { + base = "tsb-bridge" + } + + hasher := fnv.New32a() + _, _ = hasher.Write([]byte(machine + "|" + target + "|" + instanceName)) + hash := fmt.Sprintf("%06x", hasher.Sum32()&0xffffff) + + hostname := fmt.Sprintf("%s-%s-%d", base, hash, os.Getpid()) + if len(hostname) > 63 { + hostname = strings.Trim(hostname[:63], "-") + } + if hostname == "" { + return defaultHostname + } + return hostname +} + +func sanitizeHostnameLabel(value string) string { + value = strings.ToLower(strings.TrimSpace(value)) + var b strings.Builder + previousDash := false + + for _, r := range value { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') { + b.WriteRune(r) + previousDash = false + continue + } + if !previousDash { + b.WriteByte('-') + previousDash = true + } + } + + return strings.Trim(b.String(), "-") +} + +func parsePortRange(value string) (int, int, error) { + parts := strings.Split(value, "-") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("TS_PORT_RANGE invalid format %q (expected START-END)", value) + } + + start, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + return 0, 0, fmt.Errorf("TS_PORT_RANGE invalid start port: %w", err) + } + + end, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + return 0, 0, fmt.Errorf("TS_PORT_RANGE invalid end port: %w", err) + } + + if start < 1 || end > 65535 || start > end { + return 0, 0, fmt.Errorf("TS_PORT_RANGE out of bounds: %d-%d", start, end) + } + + return start, end, nil +} + +func selectAvailablePort(seed string, start, end int) (int, error) { + span := end - start + 1 + if span <= 0 { + return 0, fmt.Errorf("TS_PORT_RANGE has invalid span: %d", span) + } + + hasher := fnv.New32a() + _, _ = hasher.Write([]byte(seed)) + offset := int(int64(hasher.Sum32()) % int64(span)) + + for i := 0; i < span; i++ { + port := start + ((offset + i) % span) + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + continue + } + if err := listener.Close(); err != nil { + continue + } + return port, nil + } + + return 0, fmt.Errorf("TS_PORT_RANGE has no free ports in %d-%d", start, end) +} diff --git a/main_test.go b/internal/config/config_test.go similarity index 55% rename from main_test.go rename to internal/config/config_test.go index 2800f06..5dd7361 100644 --- a/main_test.go +++ b/internal/config/config_test.go @@ -1,208 +1,12 @@ -package main +package config import ( - "context" - "errors" - "fmt" - "io" - "log/slog" - "net" "os" "strings" - "sync/atomic" "testing" "time" ) -// mockDialer implements Dialer for testing without tsnet. -type mockDialer struct { - dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) -} - -func (m *mockDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { - return m.dialFunc(ctx, network, addr) -} - -// TestDialerInterfaceSatisfaction verifies that mockDialer satisfies the Dialer interface. -// This is a compile-time check: if Dialer doesn't exist or has a different signature, this fails. -var _ Dialer = (*mockDialer)(nil) - -func TestHandleConnWithDialer(t *testing.T) { - initLogger(Config{LogFormat: "text"}) - - tests := []struct { - name string - dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) - wantErrors int64 - wantTotalConn int64 - }{ - { - name: "successful proxy via dialer", - dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { - // Return a pipe that immediately closes (simulates short-lived connection) - server, client := net.Pipe() - go func() { - // Echo one read then close - buf := make([]byte, 1024) - n, _ := server.Read(buf) - if n > 0 { - _, _ = server.Write(buf[:n]) - } - server.Close() - }() - return client, nil - }, - wantErrors: 0, - wantTotalConn: 1, - }, - { - name: "dial failure increments errors", - dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { - return nil, errors.New("connection refused") - }, - wantErrors: 1, - wantTotalConn: 1, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Reset metrics - oldMetrics := metrics - metrics = Metrics{} - defer func() { metrics = oldMetrics }() - - dialer := &mockDialer{dialFunc: tt.dialFunc} - cfg := Config{ - Target: "100.64.0.1:3389", - ConnectTimeout: 5 * time.Second, - } - - // Create a client connection via pipe - clientConn, proxyConn := net.Pipe() - defer clientConn.Close() - - // Run handleConn in goroutine (it blocks until proxy finishes) - done := make(chan struct{}) - go func() { - handleConn(proxyConn, dialer, cfg) - close(done) - }() - - if tt.wantErrors == 0 { - // Send data through the proxy - _, _ = clientConn.Write([]byte("HELLO")) - - buf := make([]byte, 1024) - _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) - n, err := clientConn.Read(buf) - if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("read from proxy failed: %v", err) - } - if n > 0 && string(buf[:n]) != "HELLO" { - t.Errorf("expected echo HELLO, got %q", buf[:n]) - } - } - - // Close client side to let handleConn finish - clientConn.Close() - - select { - case <-done: - case <-time.After(3 * time.Second): - t.Fatal("handleConn did not finish in time") - } - - gotErrors := atomic.LoadInt64(&metrics.TotalErrors) - if gotErrors != tt.wantErrors { - t.Errorf("TotalErrors = %d, want %d", gotErrors, tt.wantErrors) - } - - gotTotal := atomic.LoadInt64(&metrics.TotalConnections) - if gotTotal != tt.wantTotalConn { - t.Errorf("TotalConnections = %d, want %d", gotTotal, tt.wantTotalConn) - } - }) - } -} - -func TestAcceptLoopWithDialer(t *testing.T) { - initLogger(Config{LogFormat: "text"}) - - // Snapshot metrics before test to check delta after - connsBefore := atomic.LoadInt64(&metrics.TotalConnections) - - // Mock dialer that echoes data - dialer := &mockDialer{ - dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { - server, client := net.Pipe() - go func() { - defer server.Close() - buf := make([]byte, 1024) - n, _ := server.Read(buf) - if n > 0 { - _, _ = server.Write(buf[:n]) - } - }() - return client, nil - }, - } - - cfg := Config{ - Target: "100.64.0.1:3389", - ConnectTimeout: 5 * time.Second, - MaxConnections: 1000, - } - - listener, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatalf("listen failed: %v", err) - } - - // Run accept loop in background - loopDone := make(chan error, 1) - go func() { - loopDone <- acceptLoop(listener, dialer, cfg) - }() - - // Connect a client through the accept loop - conn, err := net.Dial("tcp", listener.Addr().String()) - if err != nil { - t.Fatalf("dial failed: %v", err) - } - - _, _ = conn.Write([]byte("TEST")) - - buf := make([]byte, 1024) - _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) - n, err := conn.Read(buf) - if err != nil && !errors.Is(err, io.EOF) { - t.Fatalf("read failed: %v", err) - } - if string(buf[:n]) != "TEST" { - t.Errorf("expected TEST, got %q", buf[:n]) - } - conn.Close() - - // Close listener to stop accept loop - listener.Close() - - select { - case err := <-loopDone: - if err != nil { - t.Errorf("acceptLoop returned error: %v", err) - } - case <-time.After(3 * time.Second): - t.Fatal("acceptLoop did not stop") - } - - // Check that at least one connection was handled (use atomic reads, no struct reset) - connsAfter := atomic.LoadInt64(&metrics.TotalConnections) - if connsAfter <= connsBefore { - t.Errorf("expected TotalConnections to increase, before=%d after=%d", connsBefore, connsAfter) - } -} - func TestLoadConfig(t *testing.T) { tests := []struct { name string @@ -458,19 +262,17 @@ func TestLoadConfig(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Clear all config env vars for _, key := range []string{"TS_TARGET", "TS_AUTHKEY", "TS_TIMEOUT", "TS_VERBOSE", "TS_LOCAL_ADDR", "TS_HOSTNAME", "TS_STATE_DIR", "TS_CONTROL_URL", "TS_MAX_CONNECTIONS", "TS_HEALTH_ADDR", "TS_LOG_FORMAT", - "TS_AUTO_INSTANCE", "TS_INSTANCE_NAME", "TS_PORT_RANGE", "TS_MANUAL_MODE"} { + "TS_AUTO_INSTANCE", "TS_INSTANCE_NAME", "TS_PORT_RANGE", "TS_MANUAL_MODE", "TS_DRAIN_TIMEOUT"} { os.Unsetenv(key) } - // Set test-specific env vars for k, v := range tt.env { os.Setenv(k, v) } - cfg, err := loadConfig(tt.verbose) + cfg, err := LoadConfig(tt.verbose) if tt.wantErr { if err == nil { @@ -488,58 +290,6 @@ func TestLoadConfig(t *testing.T) { } } -func TestInitLogger(t *testing.T) { - oldLogger := logger - defer func() { logger = oldLogger }() - - tests := []struct { - name string - cfg Config - wantHandler string - wantLevel slog.Level - }{ - { - name: "default text handler", - cfg: Config{LogFormat: "text"}, - wantHandler: "*slog.TextHandler", - wantLevel: slog.LevelInfo, - }, - { - name: "json handler", - cfg: Config{LogFormat: "json"}, - wantHandler: "*slog.JSONHandler", - wantLevel: slog.LevelInfo, - }, - { - name: "verbose enables debug level", - cfg: Config{LogFormat: "text", Verbose: true}, - wantHandler: "*slog.TextHandler", - wantLevel: slog.LevelDebug, - }, - { - name: "unknown format falls back to text", - cfg: Config{LogFormat: "yaml"}, - wantHandler: "*slog.TextHandler", - wantLevel: slog.LevelInfo, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - initLogger(tt.cfg) - - handlerType := fmt.Sprintf("%T", logger.Handler()) - if handlerType != tt.wantHandler { - t.Errorf("handler type = %s, want %s", handlerType, tt.wantHandler) - } - - if !logger.Handler().Enabled(context.Background(), tt.wantLevel) { - t.Errorf("expected level %v to be enabled", tt.wantLevel) - } - }) - } -} - func TestEnvOr(t *testing.T) { tests := []struct { name string @@ -560,31 +310,8 @@ func TestEnvOr(t *testing.T) { os.Setenv(tt.key, tt.envValue) defer os.Unsetenv(tt.key) } - if got := envOr(tt.key, tt.fallback); got != tt.want { - t.Errorf("envOr(%q, %q) = %q, want %q", tt.key, tt.fallback, got, tt.want) - } - }) - } -} - -func TestIsRetryableCleanupError(t *testing.T) { - tests := []struct { - name string - err error - want bool - }{ - {name: "nil error", err: nil, want: false}, - {name: "directory not empty", err: errors.New("The directory is not empty."), want: true}, - {name: "access denied", err: errors.New("Access is denied."), want: true}, - {name: "resource busy", err: errors.New("device or resource busy"), want: true}, - {name: "non retryable", err: errors.New("invalid argument"), want: false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := isRetryableCleanupError(tt.err) - if got != tt.want { - t.Errorf("isRetryableCleanupError(%v) = %v, want %v", tt.err, got, tt.want) + if got := EnvOr(tt.key, tt.fallback); got != tt.want { + t.Errorf("EnvOr(%q, %q) = %q, want %q", tt.key, tt.fallback, got, tt.want) } }) } diff --git a/internal/health/health.go b/internal/health/health.go new file mode 100644 index 0000000..926ccc4 --- /dev/null +++ b/internal/health/health.go @@ -0,0 +1,54 @@ +package health + +import ( + "encoding/json" + "errors" + "log/slog" + "net/http" + "sync/atomic" + "time" + "ts-bridge/internal/telemetry" +) + +// StartServer initializes and runs the health and metrics HTTP server. +func StartServer(addr string, ready *atomic.Bool, logger *slog.Logger) *http.Server { + mux := http.NewServeMux() + + mux.HandleFunc("/health/live", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + }) + + mux.HandleFunc("/health/ready", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if !ready.Load() { + w.WriteHeader(http.StatusServiceUnavailable) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "not_ready"}) + return + } + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) + }) + + mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + snapshot := telemetry.GetMetrics() + _ = json.NewEncoder(w).Encode(snapshot) + }) + + server := &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + } + + go func() { + logger.Info("health server starting", "addr", addr) + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + logger.Error("health server error", "error", err) + } + }() + + return server +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go new file mode 100644 index 0000000..b3331c0 --- /dev/null +++ b/internal/proxy/proxy.go @@ -0,0 +1,206 @@ +package proxy + +import ( + "context" + "errors" + "io" + "log/slog" + "net" + "strings" + "sync" + "syscall" + "time" + + "ts-bridge/internal/config" + "ts-bridge/internal/telemetry" +) + +const ( + // bufferSize is the size of the copy buffer. 32KB chosen as sweet spot + // for RDP traffic: large enough for efficiency, small enough to avoid + // memory pressure with many concurrent connections. + bufferSize = 32 * 1024 + + // keepAliveInterval for TCP connections. 3 minutes is standard for + // most NAT/firewall idle timeouts. + keepAliveInterval = 3 * time.Minute + + // backoffMin/Max for accept loop error recovery. + backoffMin = 100 * time.Millisecond + backoffMax = 10 * time.Second +) + +// Dialer abstracts the remote connection mechanism. +// tsnet.Server satisfies this interface without an adapter. +type Dialer interface { + Dial(ctx context.Context, network, addr string) (net.Conn, error) +} + +// AcceptLoop accepts incoming connections and routes them to the dialer. +func AcceptLoop(listener net.Listener, dialer Dialer, cfg config.Config, wg *sync.WaitGroup, logger *slog.Logger) error { + backoff := backoffMin + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + + logger.Warn("accept error", "error", err, "backoff", backoff) + time.Sleep(backoff) + backoff = min(backoff*2, backoffMax) + continue + } + + // Reset backoff on successful accept + backoff = backoffMin + + // Check connection limit + current := telemetry.GetActiveConnections() + if current >= cfg.MaxConnections { + telemetry.AddRejectedConn() + logger.Warn("connection rejected: limit reached", + "current", current, + "max", cfg.MaxConnections, + "client", conn.RemoteAddr()) + _ = conn.Close() + continue + } + + wg.Add(1) + go func(c net.Conn) { + defer wg.Done() + handleConn(c, dialer, cfg, logger) + }(conn) + } +} + +// Buffer pool to reduce GC pressure. +var bufferPool = sync.Pool{ + New: func() any { + b := make([]byte, bufferSize) + return &b + }, +} + +func handleConn(client net.Conn, dialer Dialer, cfg config.Config, logger *slog.Logger) { + // Track metrics + telemetry.AddActiveConnection(1) + telemetry.AddTotalConnection() + defer telemetry.AddActiveConnection(-1) + + addr := client.RemoteAddr().String() + connStart := time.Now() + + if tcpConn, ok := client.(*net.TCPConn); ok { + if err := tcpConn.SetKeepAlive(true); err != nil { + logger.Debug("failed to set keepalive", "error", err) + } + if err := tcpConn.SetKeepAlivePeriod(keepAliveInterval); err != nil { + logger.Debug("failed to set keepalive period", "error", err) + } + } + + logger.Info("connection opened", "client", addr) + + ctx, cancel := context.WithTimeout(context.Background(), cfg.ConnectTimeout) + defer cancel() + + remote, err := dialer.Dial(ctx, "tcp", cfg.Target) + if err != nil { + telemetry.AddError() + logger.Error("dial failed", "client", addr, "target", cfg.Target, "error", err) + _ = client.Close() + return + } + + logger.Debug("tunnel established", "client", addr, "target", cfg.Target) + + bytesTx, bytesRx := proxyConnections(client, remote, addr, logger) + + telemetry.AddBytesTx(bytesTx) + telemetry.AddBytesRx(bytesRx) + + duration := time.Since(connStart) + logger.Info("connection closed", + "client", addr, + "duration", duration, + "bytes_tx", bytesTx, + "bytes_rx", bytesRx) +} + +// proxyConnections performs bidirectional copy between client and remote, +// returning the bytes transferred in each direction. +func proxyConnections(client, remote net.Conn, addr string, logger *slog.Logger) (tx, rx int64) { + var once sync.Once + closeAll := func() { + once.Do(func() { + _ = client.Close() + _ = remote.Close() + }) + } + + copyConn := func(dst, src net.Conn, direction string, counter *int64) { + defer closeAll() + + bufPtr := bufferPool.Get().(*[]byte) + defer bufferPool.Put(bufPtr) + + n, err := io.CopyBuffer(dst, src, *bufPtr) + *counter = n + + if err != nil && !IsExpectedCloseError(err) { + telemetry.AddError() + logger.Warn("copy error", + "client", addr, + "direction", direction, + "bytes", n, + "error", err) + } + } + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + copyConn(client, remote, "rx", &rx) + }() + copyConn(remote, client, "tx", &tx) + wg.Wait() + + return tx, rx +} + +// IsExpectedCloseError returns true for errors that occur during normal connection close. +func IsExpectedCloseError(err error) bool { + if err == nil { + return true + } + + if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) { + return true + } + + // Check for common syscall errors during close + if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.EPIPE) || errors.Is(err, syscall.ENOTCONN) { + return true + } + + // Fallback for error messages (cross-platform compatibility) + errStr := strings.ToLower(err.Error()) + expectedMsgs := []string{ + "use of closed network connection", + "connection reset by peer", + "forcibly closed by the remote host", + "closed pipe", + } + + for _, msg := range expectedMsgs { + if strings.Contains(errStr, msg) { + return true + } + } + + return false +} diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go new file mode 100644 index 0000000..16687c5 --- /dev/null +++ b/internal/proxy/proxy_test.go @@ -0,0 +1,231 @@ +package proxy + +import ( + "context" + "errors" + "io" + "log/slog" + "net" + "sync" + "testing" + "time" + + "ts-bridge/internal/config" + "ts-bridge/internal/telemetry" +) + +// mockDialer implements Dialer for testing without tsnet. +type mockDialer struct { + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) +} + +func (m *mockDialer) Dial(ctx context.Context, network, addr string) (net.Conn, error) { + return m.dialFunc(ctx, network, addr) +} + +var _ Dialer = (*mockDialer)(nil) + +func TestHandleConnWithDialer(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + tests := []struct { + name string + dialFunc func(ctx context.Context, network, addr string) (net.Conn, error) + wantErrors int64 + wantTotalConn int64 + }{ + { + name: "successful proxy via dialer", + dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + server, client := net.Pipe() + go func() { + buf := make([]byte, 1024) + n, _ := server.Read(buf) + if n > 0 { + _, _ = server.Write(buf[:n]) + } + server.Close() + }() + return client, nil + }, + wantErrors: 0, + wantTotalConn: 1, + }, + { + name: "dial failure increments errors", + dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nil, errors.New("connection refused") + }, + wantErrors: 1, + wantTotalConn: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + telemetry.ResetMetrics() + + dialer := &mockDialer{dialFunc: tt.dialFunc} + cfg := config.Config{ + Target: "100.64.0.1:3389", + ConnectTimeout: 5 * time.Second, + } + + clientConn, proxyConn := net.Pipe() + defer clientConn.Close() + + done := make(chan struct{}) + go func() { + handleConn(proxyConn, dialer, cfg, logger) + close(done) + }() + + if tt.wantErrors == 0 { + _, _ = clientConn.Write([]byte("HELLO")) + + buf := make([]byte, 1024) + _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := clientConn.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("read from proxy failed: %v", err) + } + if n > 0 && string(buf[:n]) != "HELLO" { + t.Errorf("expected echo HELLO, got %q", buf[:n]) + } + } + + clientConn.Close() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("handleConn did not finish in time") + } + + m := telemetry.GetMetrics() + if m.TotalErrors != tt.wantErrors { + t.Errorf("TotalErrors = %d, want %d", m.TotalErrors, tt.wantErrors) + } + + if m.TotalConnections != tt.wantTotalConn { + t.Errorf("TotalConnections = %d, want %d", m.TotalConnections, tt.wantTotalConn) + } + }) + } +} + +func TestAcceptLoopWithDialer(t *testing.T) { + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + telemetry.ResetMetrics() + + dialer := &mockDialer{ + dialFunc: func(ctx context.Context, network, addr string) (net.Conn, error) { + server, client := net.Pipe() + go func() { + defer server.Close() + buf := make([]byte, 1024) + n, _ := server.Read(buf) + if n > 0 { + _, _ = server.Write(buf[:n]) + } + }() + return client, nil + }, + } + + cfg := config.Config{ + Target: "100.64.0.1:3389", + ConnectTimeout: 5 * time.Second, + MaxConnections: 1000, + } + + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen failed: %v", err) + } + + var wg sync.WaitGroup + loopDone := make(chan error, 1) + go func() { + loopDone <- AcceptLoop(listener, dialer, cfg, &wg, logger) + }() + + conn, err := net.Dial("tcp", listener.Addr().String()) + if err != nil { + t.Fatalf("dial failed: %v", err) + } + + _, _ = conn.Write([]byte("TEST")) + + buf := make([]byte, 1024) + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, err := conn.Read(buf) + if err != nil && !errors.Is(err, io.EOF) { + t.Fatalf("read failed: %v", err) + } + if string(buf[:n]) != "TEST" { + t.Errorf("expected TEST, got %q", buf[:n]) + } + conn.Close() + listener.Close() + + select { + case err := <-loopDone: + if err != nil { + t.Errorf("acceptLoop returned error: %v", err) + } + case <-time.After(3 * time.Second): + t.Fatal("acceptLoop did not stop") + } + + m := telemetry.GetMetrics() + if m.TotalConnections <= 0 { + t.Errorf("expected TotalConnections to increase") + } +} + +func TestIsExpectedCloseError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {name: "nil error", err: nil, want: true}, + {name: "EOF", err: io.EOF, want: true}, + {name: "net.ErrClosed", err: net.ErrClosed, want: true}, + {name: "random error", err: errors.New("some error"), want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsExpectedCloseError(tt.err) + if got != tt.want { + t.Errorf("IsExpectedCloseError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestAcceptLoopBackoff(t *testing.T) { + backoff := backoffMin + + // Simulate 5 consecutive failures + for i := 0; i < 5; i++ { + backoff = min(backoff*2, backoffMax) + } + + // After 5 doublings: 100ms -> 200ms -> 400ms -> 800ms -> 1600ms -> 3200ms + expected := 3200 * time.Millisecond + if backoff != expected { + t.Errorf("backoff after 5 failures = %v, expected %v", backoff, expected) + } + + // Verify max cap + for i := 0; i < 10; i++ { + backoff = min(backoff*2, backoffMax) + } + if backoff != backoffMax { + t.Errorf("backoff should cap at %v, got %v", backoffMax, backoff) + } +} + diff --git a/internal/telemetry/metrics.go b/internal/telemetry/metrics.go new file mode 100644 index 0000000..b4dbd79 --- /dev/null +++ b/internal/telemetry/metrics.go @@ -0,0 +1,67 @@ +package telemetry + +import "sync/atomic" + +// Metrics tracks operational statistics. +type Metrics struct { + ActiveConnections int64 `json:"active_connections"` + TotalConnections int64 `json:"total_connections"` + TotalBytesTx int64 `json:"total_bytes_tx"` + TotalBytesRx int64 `json:"total_bytes_rx"` + TotalErrors int64 `json:"total_errors"` + RejectedConns int64 `json:"rejected_connections"` +} + +var globalMetrics Metrics + +// GetMetrics returns a snapshot of the current metrics. +func GetMetrics() Metrics { + return Metrics{ + ActiveConnections: atomic.LoadInt64(&globalMetrics.ActiveConnections), + TotalConnections: atomic.LoadInt64(&globalMetrics.TotalConnections), + TotalBytesTx: atomic.LoadInt64(&globalMetrics.TotalBytesTx), + TotalBytesRx: atomic.LoadInt64(&globalMetrics.TotalBytesRx), + TotalErrors: atomic.LoadInt64(&globalMetrics.TotalErrors), + RejectedConns: atomic.LoadInt64(&globalMetrics.RejectedConns), + } +} + +// AddActiveConnection increments the active connection count. +func AddActiveConnection(n int64) { + atomic.AddInt64(&globalMetrics.ActiveConnections, n) +} + +// GetActiveConnections returns the active connection count. +func GetActiveConnections() int64 { + return atomic.LoadInt64(&globalMetrics.ActiveConnections) +} + +// AddTotalConnection increments the total connection count. +func AddTotalConnection() { + atomic.AddInt64(&globalMetrics.TotalConnections, 1) +} + +// AddBytesTx adds to the total bytes transmitted. +func AddBytesTx(n int64) { + atomic.AddInt64(&globalMetrics.TotalBytesTx, n) +} + +// AddBytesRx adds to the total bytes received. +func AddBytesRx(n int64) { + atomic.AddInt64(&globalMetrics.TotalBytesRx, n) +} + +// AddError increments the total error count. +func AddError() { + atomic.AddInt64(&globalMetrics.TotalErrors, 1) +} + +// AddRejectedConn increments the rejected connection count. +func AddRejectedConn() { + atomic.AddInt64(&globalMetrics.RejectedConns, 1) +} + +// ResetMetrics is used for testing. +func ResetMetrics() { + globalMetrics = Metrics{} +} diff --git a/main.go b/main.go index 97de070..62912bf 100644 --- a/main.go +++ b/main.go @@ -1,24 +1,15 @@ -// Package main implements a TCP bridge over Tailscale's mesh network. -// It creates an ephemeral tsnet node and forwards local connections -// to a remote target through Tailscale's encrypted tunnel. package main import ( "context" - "encoding/json" - "errors" "flag" "fmt" - "hash/fnv" - "io" "log/slog" "net" "net/http" "os" "os/signal" - "path/filepath" "runtime" - "strconv" "strings" "sync" "sync/atomic" @@ -26,6 +17,10 @@ import ( "time" "tailscale.com/tsnet" + + "ts-bridge/internal/config" + "ts-bridge/internal/health" + "ts-bridge/internal/proxy" ) // Build-time variables set via ldflags. @@ -34,74 +29,12 @@ var ( commit = "unknown" ) -// Constants for tuning. const ( - // bufferSize is the size of the copy buffer. 32KB chosen as sweet spot - // for RDP traffic: large enough for efficiency, small enough to avoid - // memory pressure with many concurrent connections. - bufferSize = 32 * 1024 - - // keepAliveInterval for TCP connections. 3 minutes is standard for - // most NAT/firewall idle timeouts. - keepAliveInterval = 3 * time.Minute - - // defaultTimeout for tsnet initialization and dial operations. - defaultTimeout = 30 * time.Second - - // defaultMaxConnections prevents resource exhaustion. - defaultMaxConnections = 1000 - - // Default runtime values. - defaultLocalAddr = "127.0.0.1:33389" - defaultHostname = "ts-bridge" - defaultStateDir = "./ts-state" - defaultAutoPortRange = "33389-34388" cleanupMaxAttempts = 5 cleanupRetryDelay = 150 * time.Millisecond - - // backoffMin/Max for accept loop error recovery. - backoffMin = 100 * time.Millisecond - backoffMax = 10 * time.Second - - // stateDirPerms ensures state directory is only readable by owner. - stateDirPerms = 0700 + stateDirPerms = 0700 ) -// Dialer abstracts the remote connection mechanism. -// tsnet.Server satisfies this interface without an adapter. -type Dialer interface { - Dial(ctx context.Context, network, addr string) (net.Conn, error) -} - -// Config holds the bridge configuration. -type Config struct { - LocalAddr string - Target string - AuthKey string // #nosec G117 -- internal struct, never serialized - Hostname string - StateDir string - ControlURL string - ConnectTimeout time.Duration - MaxConnections int64 - HealthAddr string - Verbose bool - LogFormat string - AutoInstance bool - EphemeralState bool -} - -// Metrics tracks operational statistics. -type Metrics struct { - ActiveConnections int64 `json:"active_connections"` - TotalConnections int64 `json:"total_connections"` - TotalBytesTx int64 `json:"total_bytes_tx"` - TotalBytesRx int64 `json:"total_bytes_rx"` - TotalErrors int64 `json:"total_errors"` - RejectedConns int64 `json:"rejected_connections"` -} - -var metrics Metrics - // Logger is the global structured logger. var logger *slog.Logger @@ -116,7 +49,7 @@ func main() { os.Exit(0) } - cfg, err := loadConfig(*verbose) + cfg, err := config.LoadConfig(*verbose) if err != nil { fmt.Fprintf(os.Stderr, "Config error: %v\n", err) os.Exit(1) @@ -131,7 +64,7 @@ func main() { } } -func initLogger(cfg Config) { +func initLogger(cfg config.Config) { var handler slog.Handler opts := &slog.HandlerOptions{ Level: slog.LevelInfo, @@ -148,282 +81,6 @@ func initLogger(cfg Config) { logger = slog.New(handler) } -func loadConfig(verboseFlag bool) (Config, error) { - target, err := parseTarget() - if err != nil { - return Config{}, err - } - - authKey, err := parseAuthKey() - if err != nil { - return Config{}, err - } - - timeout := defaultTimeout - if t := os.Getenv("TS_TIMEOUT"); t != "" { - d, err := time.ParseDuration(t) - if err != nil { - return Config{}, fmt.Errorf("TS_TIMEOUT invalid: %w", err) - } - timeout = d - } - - maxConns := int64(defaultMaxConnections) - if m := os.Getenv("TS_MAX_CONNECTIONS"); m != "" { - n, err := strconv.ParseInt(m, 10, 64) - if err != nil || n < 1 { - return Config{}, fmt.Errorf("TS_MAX_CONNECTIONS invalid: %w", err) - } - maxConns = n - } - - verbose := verboseFlag || os.Getenv("TS_VERBOSE") == "true" || os.Getenv("TS_VERBOSE") == "1" - - cfg := Config{ - LocalAddr: os.Getenv("TS_LOCAL_ADDR"), - Target: target, - AuthKey: authKey, - Hostname: os.Getenv("TS_HOSTNAME"), - StateDir: os.Getenv("TS_STATE_DIR"), - ControlURL: os.Getenv("TS_CONTROL_URL"), - ConnectTimeout: timeout, - MaxConnections: maxConns, - HealthAddr: os.Getenv("TS_HEALTH_ADDR"), - Verbose: verbose, - LogFormat: envOr("TS_LOG_FORMAT", "text"), - } - - if err := applyAutoInstanceConfig(&cfg); err != nil { - return Config{}, err - } - - if cfg.LocalAddr == "" { - cfg.LocalAddr = defaultLocalAddr - } - if cfg.Hostname == "" { - cfg.Hostname = defaultHostname - } - if cfg.StateDir == "" { - cfg.StateDir = defaultStateDir - } - - return cfg, nil -} - -func parseTarget() (string, error) { - target := os.Getenv("TS_TARGET") - if target == "" { - return "", errors.New("TS_TARGET is required (format: HOST:PORT)") - } - - host, portStr, err := net.SplitHostPort(target) - if err != nil { - return "", fmt.Errorf("TS_TARGET invalid format: %w", err) - } - if host == "" { - return "", errors.New("TS_TARGET: host cannot be empty") - } - port, err := strconv.Atoi(portStr) - if err != nil || port < 1 || port > 65535 { - return "", fmt.Errorf("TS_TARGET: invalid port %q", portStr) - } - return target, nil -} - -func parseAuthKey() (string, error) { - authKey := os.Getenv("TS_AUTHKEY") - if authKey == "" { - return "", errors.New("TS_AUTHKEY is required") - } - if !strings.HasPrefix(authKey, "tskey-") && !strings.HasPrefix(authKey, "hskey-") { - return "", errors.New("TS_AUTHKEY: invalid format (must start with tskey- or hskey-)") - } - return authKey, nil -} - -func envOr(key, fallback string) string { - if v := os.Getenv(key); v != "" { - return v - } - return fallback -} - -func applyAutoInstanceConfig(cfg *Config) error { - cfg.AutoInstance = shouldEnableAutoInstance() - if !cfg.AutoInstance { - return nil - } - - instanceName := os.Getenv("TS_INSTANCE_NAME") - portRange := envOr("TS_PORT_RANGE", defaultAutoPortRange) - - if cfg.LocalAddr == "" { - localAddr, err := deriveAutoLocalAddr(cfg.Target, instanceName, portRange) - if err != nil { - return err - } - cfg.LocalAddr = localAddr - } - - if cfg.Hostname == "" { - cfg.Hostname = deriveAutoHostname(cfg.Target, instanceName) - } - - if cfg.StateDir == "" { - cfg.StateDir = filepath.Join(os.TempDir(), "ts-bridge", cfg.Hostname) - cfg.EphemeralState = true - } - - return nil -} - -func shouldEnableAutoInstance() bool { - if parseBoolEnv(os.Getenv("TS_MANUAL_MODE")) { - return false - } - - rawAutoMode := strings.TrimSpace(os.Getenv("TS_AUTO_INSTANCE")) - if rawAutoMode == "" { - return true - } - - return parseBoolEnv(rawAutoMode) -} - -func parseBoolEnv(value string) bool { - switch strings.ToLower(strings.TrimSpace(value)) { - case "1", "true", "yes", "on": - return true - default: - return false - } -} - -func deriveAutoLocalAddr(target, instanceName, portRange string) (string, error) { - start, end, err := parsePortRange(portRange) - if err != nil { - return "", err - } - - hostName, err := os.Hostname() - if err != nil || hostName == "" { - hostName = "unknown-host" - } - - seed := fmt.Sprintf("%s|%s|%s", hostName, target, instanceName) - port, err := selectAvailablePort(seed, start, end) - if err != nil { - return "", err - } - - return fmt.Sprintf("127.0.0.1:%d", port), nil -} - -func deriveAutoHostname(target, instanceName string) string { - hostName, err := os.Hostname() - if err != nil || hostName == "" { - hostName = "unknown-host" - } - - machine := sanitizeHostnameLabel(hostName) - instance := sanitizeHostnameLabel(instanceName) - if instance == "" { - instance = machine - } - if instance == "" { - instance = "bridge" - } - - base := "tsb-" + instance - if len(base) > 30 { - base = strings.Trim(base[:30], "-") - } - if base == "" { - base = "tsb-bridge" - } - - hasher := fnv.New32a() - _, _ = hasher.Write([]byte(machine + "|" + target + "|" + instanceName)) - hash := fmt.Sprintf("%06x", hasher.Sum32()&0xffffff) - - hostname := fmt.Sprintf("%s-%s-%d", base, hash, os.Getpid()) - if len(hostname) > 63 { - hostname = strings.Trim(hostname[:63], "-") - } - if hostname == "" { - return defaultHostname - } - return hostname -} - -func sanitizeHostnameLabel(value string) string { - value = strings.ToLower(strings.TrimSpace(value)) - var b strings.Builder - previousDash := false - - for _, r := range value { - if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') { - b.WriteRune(r) - previousDash = false - continue - } - if !previousDash { - b.WriteByte('-') - previousDash = true - } - } - - return strings.Trim(b.String(), "-") -} - -func parsePortRange(value string) (int, int, error) { - parts := strings.Split(value, "-") - if len(parts) != 2 { - return 0, 0, fmt.Errorf("TS_PORT_RANGE invalid format %q (expected START-END)", value) - } - - start, err := strconv.Atoi(strings.TrimSpace(parts[0])) - if err != nil { - return 0, 0, fmt.Errorf("TS_PORT_RANGE invalid start port: %w", err) - } - - end, err := strconv.Atoi(strings.TrimSpace(parts[1])) - if err != nil { - return 0, 0, fmt.Errorf("TS_PORT_RANGE invalid end port: %w", err) - } - - if start < 1 || end > 65535 || start > end { - return 0, 0, fmt.Errorf("TS_PORT_RANGE out of bounds: %d-%d", start, end) - } - - return start, end, nil -} - -func selectAvailablePort(seed string, start, end int) (int, error) { - span := end - start + 1 - if span <= 0 { - return 0, fmt.Errorf("TS_PORT_RANGE has invalid span: %d", span) - } - - hasher := fnv.New32a() - _, _ = hasher.Write([]byte(seed)) - offset := int(int64(hasher.Sum32()) % int64(span)) - - for i := 0; i < span; i++ { - port := start + ((offset + i) % span) - listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) - if err != nil { - continue - } - if err := listener.Close(); err != nil { - continue - } - return port, nil - } - - return 0, fmt.Errorf("TS_PORT_RANGE has no free ports in %d-%d", start, end) -} - func ensureStateDir(dir string) error { info, err := os.Stat(dir) if os.IsNotExist(err) { @@ -483,7 +140,7 @@ func isRetryableCleanupError(err error) bool { strings.Contains(errStr, "device or resource busy") } -func run(cfg Config) error { +func run(cfg config.Config) error { if err := ensureStateDir(cfg.StateDir); err != nil { return err } @@ -535,8 +192,9 @@ func run(cfg Config) error { // Start health server if configured var ready atomic.Bool var healthServer *http.Server + if cfg.HealthAddr != "" { - healthServer = startHealthServer(cfg.HealthAddr, &ready) + healthServer = health.StartServer(cfg.HealthAddr, &ready, logger) } printBanner(cfg) @@ -558,65 +216,39 @@ func run(cfg Config) error { logger.Error("error closing health server", "error", err) } } - if err := server.Close(); err != nil { - logger.Error("error closing tsnet server", "error", err) - } }() ready.Store(true) - return acceptLoop(listener, server, cfg) -} - -func startHealthServer(addr string, ready *atomic.Bool) *http.Server { - mux := http.NewServeMux() - - mux.HandleFunc("/health/live", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) - }) - - mux.HandleFunc("/health/ready", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if !ready.Load() { - w.WriteHeader(http.StatusServiceUnavailable) - _ = json.NewEncoder(w).Encode(map[string]string{"status": "not_ready"}) - return + var activeConns sync.WaitGroup + errAccept := proxy.AcceptLoop(listener, server, cfg, &activeConns, logger) + + if cfg.DrainTimeout > 0 { + logger.Info("draining active connections", "timeout", cfg.DrainTimeout) + drainCtx, drainCancel := context.WithTimeout(context.Background(), cfg.DrainTimeout) + defer drainCancel() + + done := make(chan struct{}) + go func() { + activeConns.Wait() + close(done) + }() + + select { + case <-done: + logger.Info("all active connections drained gracefully") + case <-drainCtx.Done(): + logger.Warn("drain timeout exceeded, forcing shutdown") } - w.WriteHeader(http.StatusOK) - _ = json.NewEncoder(w).Encode(map[string]string{"status": "ok"}) - }) - - mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - snapshot := Metrics{ - ActiveConnections: atomic.LoadInt64(&metrics.ActiveConnections), - TotalConnections: atomic.LoadInt64(&metrics.TotalConnections), - TotalBytesTx: atomic.LoadInt64(&metrics.TotalBytesTx), - TotalBytesRx: atomic.LoadInt64(&metrics.TotalBytesRx), - TotalErrors: atomic.LoadInt64(&metrics.TotalErrors), - RejectedConns: atomic.LoadInt64(&metrics.RejectedConns), - } - _ = json.NewEncoder(w).Encode(snapshot) - }) - - server := &http.Server{ - Addr: addr, - Handler: mux, - ReadHeaderTimeout: 10 * time.Second, } - go func() { - logger.Info("health server starting", "addr", addr) - if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { - logger.Error("health server error", "error", err) - } - }() + if err := server.Close(); err != nil { + logger.Error("error closing tsnet server", "error", err) + } - return server + return errAccept } -func printBanner(cfg Config) { +func printBanner(cfg config.Config) { fmt.Println() fmt.Println(" +---------------------------------------+") fmt.Printf(" | TAILSCALE BRIDGE %-14s |\n", version) @@ -628,172 +260,3 @@ func printBanner(cfg Config) { fmt.Println(" Waiting for connections...") fmt.Println() } - -func acceptLoop(listener net.Listener, dialer Dialer, cfg Config) error { - backoff := backoffMin - - for { - conn, err := listener.Accept() - if err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - - logger.Warn("accept error", "error", err, "backoff", backoff) - time.Sleep(backoff) - backoff = min(backoff*2, backoffMax) - continue - } - - // Reset backoff on successful accept - backoff = backoffMin - - // Check connection limit - current := atomic.LoadInt64(&metrics.ActiveConnections) - if current >= cfg.MaxConnections { - atomic.AddInt64(&metrics.RejectedConns, 1) - logger.Warn("connection rejected: limit reached", - "current", current, - "max", cfg.MaxConnections, - "client", conn.RemoteAddr()) - _ = conn.Close() - continue - } - - go handleConn(conn, dialer, cfg) - } -} - -// Buffer pool to reduce GC pressure. -var bufferPool = sync.Pool{ - New: func() any { - b := make([]byte, bufferSize) - return &b - }, -} - -func handleConn(client net.Conn, dialer Dialer, cfg Config) { - // Track metrics - atomic.AddInt64(&metrics.ActiveConnections, 1) - atomic.AddInt64(&metrics.TotalConnections, 1) - defer atomic.AddInt64(&metrics.ActiveConnections, -1) - - addr := client.RemoteAddr().String() - connStart := time.Now() - - if tcpConn, ok := client.(*net.TCPConn); ok { - if err := tcpConn.SetKeepAlive(true); err != nil { - logger.Debug("failed to set keepalive", "error", err) - } - if err := tcpConn.SetKeepAlivePeriod(keepAliveInterval); err != nil { - logger.Debug("failed to set keepalive period", "error", err) - } - } - - logger.Info("connection opened", "client", addr) - - ctx, cancel := context.WithTimeout(context.Background(), cfg.ConnectTimeout) - defer cancel() - - remote, err := dialer.Dial(ctx, "tcp", cfg.Target) - if err != nil { - atomic.AddInt64(&metrics.TotalErrors, 1) - logger.Error("dial failed", "client", addr, "target", cfg.Target, "error", err) - _ = client.Close() - return - } - - logger.Debug("tunnel established", "client", addr, "target", cfg.Target) - - bytesTx, bytesRx := proxyConnections(client, remote, addr) - - atomic.AddInt64(&metrics.TotalBytesTx, bytesTx) - atomic.AddInt64(&metrics.TotalBytesRx, bytesRx) - - duration := time.Since(connStart) - logger.Info("connection closed", - "client", addr, - "duration", duration, - "bytes_tx", bytesTx, - "bytes_rx", bytesRx) -} - -// proxyConnections performs bidirectional copy between client and remote, -// returning the bytes transferred in each direction. -func proxyConnections(client, remote net.Conn, addr string) (tx, rx int64) { - var once sync.Once - closeAll := func() { - once.Do(func() { - _ = client.Close() - _ = remote.Close() - }) - } - - copyConn := func(dst, src net.Conn, direction string, counter *int64) { - defer closeAll() - - bufPtr := bufferPool.Get().(*[]byte) - defer bufferPool.Put(bufPtr) - - n, err := io.CopyBuffer(dst, src, *bufPtr) - atomic.AddInt64(counter, n) - - if err != nil && !isExpectedCloseError(err) { - atomic.AddInt64(&metrics.TotalErrors, 1) - logger.Warn("copy error", - "client", addr, - "direction", direction, - "bytes", n, - "error", err) - } - } - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - copyConn(client, remote, "rx", &rx) - }() - copyConn(remote, client, "tx", &tx) - wg.Wait() - - return tx, rx -} - -// isExpectedCloseError returns true for errors that occur during normal connection close. -func isExpectedCloseError(err error) bool { - if err == nil { - return true - } - if errors.Is(err, net.ErrClosed) { - return true - } - if errors.Is(err, io.EOF) { - return true - } - // Check for common syscall errors during close - if errors.Is(err, syscall.ECONNRESET) { - return true - } - if errors.Is(err, syscall.EPIPE) { - return true - } - if errors.Is(err, syscall.ENOTCONN) { - return true - } - // Fallback for error messages (cross-platform compatibility) - errStr := strings.ToLower(err.Error()) - if strings.Contains(errStr, "use of closed network connection") { - return true - } - if strings.Contains(errStr, "connection reset by peer") { - return true - } - if strings.Contains(errStr, "forcibly closed by the remote host") { - return true - } - if strings.Contains(errStr, "closed pipe") { - return true - } - return false -} diff --git a/main_integration_test.go b/main_integration_test.go index 49aaa40..ff61be5 100644 --- a/main_integration_test.go +++ b/main_integration_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "log/slog" "net" "os" "runtime" @@ -12,6 +13,10 @@ import ( "sync/atomic" "testing" "time" + + "ts-bridge/internal/config" + "ts-bridge/internal/health" + "ts-bridge/internal/telemetry" ) // TestProxyBidirectionalFlow tests that data flows correctly in both directions. @@ -305,92 +310,35 @@ func TestConcurrentConnections(t *testing.T) { // TestConnectionLimit tests that connection limits are enforced. func TestConnectionLimit(t *testing.T) { - // Reset metrics for this test - oldMetrics := metrics - metrics = Metrics{} - defer func() { metrics = oldMetrics }() + telemetry.ResetMetrics() - cfg := Config{ + cfg := config.Config{ MaxConnections: 2, } - // Track active connections - var activeConns int64 - var rejectedConns int64 - // Simulate connection limit check for i := 0; i < 5; i++ { - current := atomic.LoadInt64(&activeConns) + current := telemetry.GetActiveConnections() if current >= cfg.MaxConnections { - atomic.AddInt64(&rejectedConns, 1) + telemetry.AddRejectedConn() continue } - atomic.AddInt64(&activeConns, 1) - } - - if rejectedConns != 3 { - t.Errorf("expected 3 rejected connections, got %d", rejectedConns) - } - if activeConns != 2 { - t.Errorf("expected 2 active connections, got %d", activeConns) - } -} - -// TestIsExpectedCloseError tests error classification. -func TestIsExpectedCloseError(t *testing.T) { - tests := []struct { - name string - err error - expected bool - }{ - {"nil error", nil, true}, - {"EOF", io.EOF, true}, - {"net.ErrClosed", net.ErrClosed, true}, - {"random error", errors.New("random error"), false}, - {"closed network", errors.New("use of closed network connection"), true}, - {"connection reset", errors.New("connection reset by peer"), true}, - {"windows wsarecv forced close", errors.New("wsarecv: An existing connection was forcibly closed by the remote host"), true}, - {"closed pipe", errors.New("io: read/write on closed pipe"), true}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := isExpectedCloseError(tt.err) - if result != tt.expected { - t.Errorf("isExpectedCloseError(%v) = %v, expected %v", tt.err, result, tt.expected) - } - }) - } -} - -// TestAcceptLoopBackoff tests exponential backoff behavior. -func TestAcceptLoopBackoff(t *testing.T) { - backoff := backoffMin - - // Simulate 5 consecutive failures - for i := 0; i < 5; i++ { - backoff = min(backoff*2, backoffMax) - } - - // After 5 doublings: 100ms -> 200ms -> 400ms -> 800ms -> 1600ms -> 3200ms - expected := 3200 * time.Millisecond - if backoff != expected { - t.Errorf("backoff after 5 failures = %v, expected %v", backoff, expected) + telemetry.AddActiveConnection(1) } - // Verify max cap - for i := 0; i < 10; i++ { - backoff = min(backoff*2, backoffMax) + m := telemetry.GetMetrics() + if m.RejectedConns != 3 { + t.Errorf("expected 3 rejected connections, got %d", m.RejectedConns) } - if backoff != backoffMax { - t.Errorf("backoff should cap at %v, got %v", backoffMax, backoff) + if m.ActiveConnections != 2 { + t.Errorf("expected 2 active connections, got %d", m.ActiveConnections) } } // TestEnsureStateDir tests state directory creation with permissions. func TestEnsureStateDir(t *testing.T) { // Initialize logger for test - initLogger(Config{LogFormat: "text"}) + initLogger(config.Config{LogFormat: "text"}) dir := t.TempDir() + "/test-state" @@ -423,7 +371,7 @@ func TestEnsureStateDir(t *testing.T) { // TestHealthEndpoints tests health server responses. func TestHealthEndpoints(t *testing.T) { // Initialize logger for test - initLogger(Config{LogFormat: "text"}) + l := slog.New(slog.NewTextHandler(io.Discard, nil)) // Find free port listener, err := net.Listen("tcp", "127.0.0.1:0") @@ -434,7 +382,7 @@ func TestHealthEndpoints(t *testing.T) { listener.Close() var ready atomic.Bool - server := startHealthServer(addr, &ready) + server := health.StartServer(addr, &ready, l) defer func() { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -537,10 +485,7 @@ func TestHealthEndpoints(t *testing.T) { // TestMetricsAtomicity tests that metrics updates are thread-safe. func TestMetricsAtomicity(t *testing.T) { - // Reset metrics - oldMetrics := metrics - metrics = Metrics{} - defer func() { metrics = oldMetrics }() + telemetry.ResetMetrics() const goroutines = 100 const iterations = 1000 @@ -551,8 +496,8 @@ func TestMetricsAtomicity(t *testing.T) { go func() { defer wg.Done() for j := 0; j < iterations; j++ { - atomic.AddInt64(&metrics.TotalConnections, 1) - atomic.AddInt64(&metrics.TotalBytesTx, 100) + telemetry.AddTotalConnection() + telemetry.AddBytesTx(100) } }() } @@ -560,11 +505,12 @@ func TestMetricsAtomicity(t *testing.T) { wg.Wait() expected := int64(goroutines * iterations) - if metrics.TotalConnections != expected { - t.Errorf("TotalConnections = %d, expected %d", metrics.TotalConnections, expected) + m := telemetry.GetMetrics() + if m.TotalConnections != expected { + t.Errorf("TotalConnections = %d, expected %d", m.TotalConnections, expected) } - if metrics.TotalBytesTx != expected*100 { - t.Errorf("TotalBytesTx = %d, expected %d", metrics.TotalBytesTx, expected*100) + if m.TotalBytesTx != expected*100 { + t.Errorf("TotalBytesTx = %d, expected %d", m.TotalBytesTx, expected*100) } } @@ -576,7 +522,7 @@ func TestVerboseConfig(t *testing.T) { defer os.Unsetenv("TS_AUTHKEY") // Test flag - cfg, err := loadConfig(true) + cfg, err := config.LoadConfig(true) if err != nil { t.Fatalf("loadConfig failed: %v", err) } @@ -588,7 +534,7 @@ func TestVerboseConfig(t *testing.T) { os.Setenv("TS_VERBOSE", "true") defer os.Unsetenv("TS_VERBOSE") - cfg, err = loadConfig(false) + cfg, err = config.LoadConfig(false) if err != nil { t.Fatalf("loadConfig failed: %v", err) } @@ -605,19 +551,19 @@ func TestMaxConnectionsConfig(t *testing.T) { defer os.Unsetenv("TS_AUTHKEY") // Test default - cfg, err := loadConfig(false) + cfg, err := config.LoadConfig(false) if err != nil { t.Fatalf("loadConfig failed: %v", err) } - if cfg.MaxConnections != defaultMaxConnections { - t.Errorf("expected default %d, got %d", defaultMaxConnections, cfg.MaxConnections) + if cfg.MaxConnections != 1000 { + t.Errorf("expected default %d, got %d", 1000, cfg.MaxConnections) } // Test custom os.Setenv("TS_MAX_CONNECTIONS", "500") defer os.Unsetenv("TS_MAX_CONNECTIONS") - cfg, err = loadConfig(false) + cfg, err = config.LoadConfig(false) if err != nil { t.Fatalf("loadConfig failed: %v", err) } @@ -627,7 +573,7 @@ func TestMaxConnectionsConfig(t *testing.T) { // Test invalid os.Setenv("TS_MAX_CONNECTIONS", "invalid") - _, err = loadConfig(false) + _, err = config.LoadConfig(false) if err == nil { t.Error("expected error for invalid max connections") } diff --git a/site/src/content/docs/getting-started.md b/site/src/content/docs/getting-started.md index 5a905ad..ff0b5e4 100644 --- a/site/src/content/docs/getting-started.md +++ b/site/src/content/docs/getting-started.md @@ -61,6 +61,7 @@ TS_CONTROL_URL=https://vpn.example.com | `TS_INSTANCE_NAME` | _(empty)_ | Stable instance alias for deterministic local port selection. | | `TS_PORT_RANGE` | `33389-34388` | Port range for auto mode (`START-END`). | | `TS_TIMEOUT` | `30s` | Timeout for Tailscale initialization and dial. Go duration format. | +| `TS_DRAIN_TIMEOUT` | `15s` | Timeout for graceful drain of active connections on shutdown. Go duration format. | | `TS_MAX_CONNECTIONS` | `1000` | Maximum concurrent connections before rejecting new ones. | | `TS_HEALTH_ADDR` | _(disabled)_ | Address for health/metrics HTTP server. | | `TS_VERBOSE` | `false` | Enable debug logging. Also available as `-v` flag. |