From 7d1ccbcdb981d99faa6d7a9023c4f3b9ce0bfd4d Mon Sep 17 00:00:00 2001 From: Jonathan Chun Date: Fri, 13 Feb 2026 02:06:43 -0500 Subject: [PATCH 1/4] Add secrets protection design doc Covers local heuristic-based secrets protection with two phases: pre-execution path checking to block access to sensitive files, and post-execution output scrubbing to redact known secret patterns from stdout/stderr before returning to the LLM. --- .../2026-02-13-secrets-protection-design.md | 266 ++++++++++++++++++ 1 file changed, 266 insertions(+) create mode 100644 docs/plans/2026-02-13-secrets-protection-design.md diff --git a/docs/plans/2026-02-13-secrets-protection-design.md b/docs/plans/2026-02-13-secrets-protection-design.md new file mode 100644 index 0000000..9cd9628 --- /dev/null +++ b/docs/plans/2026-02-13-secrets-protection-design.md @@ -0,0 +1,266 @@ +# Secrets Protection Design + +## Problem + +Shellguard currently has zero protection against reading sensitive files or +leaking secrets. Any allowed command (`cat`, `grep`, `head`, `printenv`, etc.) +can freely access `.env`, private keys, cloud credentials, and similar files. +The `download_file` tool has no path validation at all. + +### Threat Model + +1. **Malicious exfiltration** — prompt injection or model misbehavior causes the + LLM to issue tool calls that read secrets (e.g., `cat .env`, `printenv`) +2. **Accidental exposure** — the LLM innocently reads secrets during + debugging/exploration and includes them in responses, exposing them in + logs/chat history +3. **Privacy** — users don't want secrets sent to LLM API endpoints at all + +All three concerns demand the same solution: **local heuristics that block or +redact secrets before they ever leave the machine.** No network calls, no LLM +involvement. + +## Architecture + +A new standalone `secrets` package (zero internal dependencies) provides two +capabilities integrated into the existing pipeline: + +``` +User Input + │ + ▼ + Parse (parser) + │ + ▼ + Validate (validator) ← existing manifest-based validation + │ + ▼ + CheckSecrets (secrets) ← NEW: rejects sensitive file access + │ + ▼ + Reconstruct (ssh) + │ + ▼ + Execute (ssh) + │ + ▼ + Truncate (output) + │ + ▼ + ScrubSecrets (secrets) ← NEW: redacts secrets from output + │ + ▼ + Return to LLM +``` + +Both stages are function fields on `Core` for test injection, following the +existing pattern (`Parse`, `Validate`, `Reconstruct`, `Truncate`). + +## Phase 1: Pre-execution Path Checking + +### Sensitive Path Patterns + +Default patterns are hardcoded in the package. Users can override via config. + +| Category | Patterns | +| ----------------- | ----------------------------------------------------------------------------------------- | +| Env files | `.env`, `.env.*` (`.env.local`, `.env.production`, etc.) | +| SSH keys | `.ssh/id_*`, `.ssh/authorized_keys`, `.ssh/known_hosts` | +| TLS/Certs | `*.pem`, `*.key`, `*.pfx`, `*.p12` | +| Cloud credentials | `.aws/credentials`, `.aws/config`, `.gcloud/credentials.db`, `.azure/`, `.config/gcloud/` | +| K8s/Docker | `.kube/config`, `.docker/config.json` | +| App credentials | `credentials.json`, `service-account*.json`, `.netrc`, `.pgpass`, `.my.cnf` | +| Git | `.git-credentials`, `.gitconfig` | +| System | `/etc/shadow`, `/etc/gshadow`, `/etc/master.passwd` | +| Generic | `*secret*`, `*credential*`, `*token*` in filenames | + +### How It Works + +A new function `secrets.CheckPipeline(pipeline, config)` is called from +`Core.Execute()` between `Validate` and `Reconstruct`. For each segment's args: + +1. Normalize the path (`path.Clean`, resolve `~`, strip trailing slashes) +2. Check the **basename** against filename patterns (`.env`, `id_rsa`, etc.) +3. Check the **full path** against directory patterns (`.ssh/`, `.aws/`, etc.) +4. If a match is found, return a `SecretsError` with the specific pattern matched + +### Special Command Handling + +- **`printenv` with no args** — blocked (dumps all env vars including secrets) +- **`printenv VAR_NAME`** — blocked if `VAR_NAME` matches sensitive env var + patterns: `*KEY*`, `*SECRET*`, `*TOKEN*`, `*PASSWORD*`, `*CREDENTIAL*`, `*AUTH*` +- **`download_file`** — `remotePath` checked against the same pattern set in + `Core.DownloadFile()` + +### Configuration + +```go +type SecretsConfig struct { + // AllowedPaths overrides default blocking for specific paths. + // e.g., [".env.example", "/app/config/credentials.json"] + AllowedPaths []string `yaml:"allowed_paths"` + + // AdditionalPatterns adds more patterns to the default set. + AdditionalPatterns []string `yaml:"additional_patterns"` + + // DisablePathCheck disables pre-execution path checking entirely. + DisablePathCheck bool `yaml:"disable_path_check"` + + // DisableOutputScrub disables post-execution output scrubbing. + DisableOutputScrub bool `yaml:"disable_output_scrub"` +} +``` + +Hard block by default. Users can allowlist specific paths if they explicitly +choose to allow access. + +## Phase 2: Post-execution Output Scrubbing + +Catches secrets that appear in command output — e.g., `grep -r "database" +/app/config/` might return lines containing connection strings with embedded +passwords. + +### Patterns (High Confidence — Low False Positive Risk) + +| Pattern | Example | Redacted to | +| ------------------ | ------------------------------------------------- | -------------------------------------- | +| AWS Access Key | `AKIA1234567890ABCDEF` | `AKIA***REDACTED***` | +| AWS Secret Key | 40-char base64 after `aws_secret_access_key` | `***REDACTED***` | +| GitHub tokens | `ghp_xxxx`, `gho_xxxx`, `ghs_xxxx`, `github_pat_` | `ghp_***REDACTED***` | +| Stripe/OpenAI keys | `sk-xxxx`, `sk_live_xxxx`, `pk_live_xxxx` | `sk-***REDACTED***` | +| Private key blocks | `-----BEGIN (RSA\|EC\|OPENSSH) PRIVATE KEY-----` | `***REDACTED_PRIVATE_KEY***` | +| Bearer tokens | `Authorization: Bearer xxxx` | `Authorization: Bearer ***REDACTED***` | + +### Patterns (Medium Confidence — Tightened to Reduce False Positives) + +| Pattern | Tightening Heuristic | +| ---------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- | +| JWTs (`eyJ...`) | Require all 3 dot-separated segments to be valid base64 AND total length > 30 chars | +| Connection strings (`postgresql://user:pass@host`) | Only redact the password portion between `://user:` and `@` | +| Generic key-value (`password=`, `secret=`, `token=`) | Require value portion to look like a secret: min length, mixed case/digits, not a common word. `password_policy=strict` should NOT match. | + +### Implementation + +`secrets.ScrubOutput(text string) string` applies all compiled regex patterns. +Patterns are compiled once at package init. Output is already capped at 64KB by +the truncation stage, so running ~10-15 compiled regexes is sub-millisecond. + +## Integration Points (Changes to Existing Code) + +### `shellguard.go` + +Add `SecretsConfig` to `Config` struct. Pass it through to `Core`. + +### `server/server.go` + +Add two function fields to `Core`: + +```go +type Core struct { + // ... existing fields ... + CheckSecrets func(*parser.Pipeline) error // default: secrets.CheckPipeline + ScrubSecrets func(string) string // default: secrets.ScrubOutput +} +``` + +Update `Core.Execute()`: + +```go +func (c *Core) Execute(ctx context.Context, in ExecuteInput) (output.CommandResult, error) { + pipeline, err := c.Parse(in.Command) + // ... + err = c.Validate(pipeline, c.Registry) + // ... + err = c.CheckSecrets(pipeline) // NEW + // ... + cmd := c.Reconstruct(pipeline, ...) + result := c.Runner.Execute(ctx, host, cmd, timeout) + truncated := c.Truncate(...) + // Scrub both stdout and stderr // NEW + truncated.Stdout = c.ScrubSecrets(truncated.Stdout) + truncated.Stderr = c.ScrubSecrets(truncated.Stderr) + return truncated, nil +} +``` + +Update `Core.DownloadFile()` to check `remotePath` against secrets patterns. + +### No changes to `validator/` or `output/` + +Secrets checking is a separate stage, keeping clean separation of concerns. + +## Package Layout + +``` +secrets/ + secrets.go # Core types, SecretsConfig, constructor, SecretsError + paths.go # Sensitive path patterns, CheckPipeline(), CheckPath() + scrub.go # Output scrubbing patterns, ScrubOutput() + paths_test.go # Path checking unit tests + scrub_test.go # Output scrubbing unit tests + false-positive regression + security_test.go # Attack vector tests (bypass attempts) + fuzz_test.go # Fuzz tests for both path checking and scrubbing +``` + +## Testing Strategy + +Following existing conventions: `testing` only, no testify, `got`/`want` +assertions, `t.Helper()` in helpers. + +### Path Tests (`paths_test.go`) + +Table-driven tests covering all pattern categories: + +- **Blocked:** `cat .env`, `head .ssh/id_rsa`, `grep -r foo credentials.json`, + `printenv`, `printenv AWS_SECRET_KEY` +- **Allowed:** `cat README.md`, `head main.go`, `printenv PATH`, `printenv HOME` +- **Allowlist override:** `.env.example` in allowed list → `cat .env.example` passes +- **Path normalization:** `../../.env`, `./foo/../.env`, `~/.ssh/id_rsa` + +### Scrub Tests (`scrub_test.go`) + +Table-driven with input/expected output pairs: + +- **Positive:** Each pattern category with realistic examples +- **False-positive regression:** `token_count=5`, `password_policy=strict`, + `secret_garden.txt`, base64 strings that start with `eyJ` but aren't JWTs + +### Security Tests (`security_test.go`) + +`TestSec_` prefix. Bypass attempts: + +- Path traversal: `../../.env`, `/app/../../../etc/shadow` +- Encoding tricks: URL-encoded paths, unicode homoglyphs +- Indirect access: `find / -name .env`, `grep -rl password /etc/` +- Argument hiding: flags that take path values (`grep -f .env foo.txt`) + +### Fuzz Tests (`fuzz_test.go`) + +- `FuzzCheckPath` — random strings never panic, always return valid error or nil +- `FuzzScrubOutput` — random input never panics, output length ≤ input length + + redaction marker overhead + +### Cross-layer Tests (`security_pipeline_test.go`) + +Add cases to existing test file: + +- `cat .env` → rejected +- `cat README.md` → allowed +- `printenv` (bare) → rejected +- `printenv PATH` → allowed +- `head .aws/credentials` → rejected + +## Open Questions + +1. **`grep -f .env`** — Should we check flag values that take file paths? + Manifests know which flags have `takes_value: true`, but we'd need to know + which of those values are file paths. Could inspect based on the flag name + (e.g., `-f`, `--file`, `--include-from`). Start simple, iterate. + +2. **Symlinks** — We can't resolve symlinks before execution (the file is on the + remote host). The heuristic only checks the path string as written. This is + an accepted limitation of the local-heuristic approach. + +3. **`find` output** — `find / -name .env` doesn't read secrets, it lists paths. + Should we block `find` commands that search for sensitive filenames? Leaning + toward yes for `-name .env` patterns but this adds complexity. From 064c73cd4e9581fbfee8235d3ce7816a62b0776d Mon Sep 17 00:00:00 2001 From: Jonathan Chun Date: Tue, 17 Feb 2026 18:04:26 -0800 Subject: [PATCH 2/4] test: fix unix socket path length issue on macOS --- ssh/auth_test.go | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ssh/auth_test.go b/ssh/auth_test.go index ecd6e3e..913afb5 100644 --- a/ssh/auth_test.go +++ b/ssh/auth_test.go @@ -499,7 +499,14 @@ func startTestAgentWithKey(t *testing.T) string { // startTestAgentKeyring starts an ssh-agent serving the given keyring. func startTestAgentKeyring(t *testing.T, keyring agent.Agent) string { t.Helper() - sockPath := filepath.Join(t.TempDir(), "agent.sock") + // Use /tmp directly to avoid long paths on macOS causing "bind: invalid argument" + dir, err := os.MkdirTemp("/tmp", "sg-test-") + if err != nil { + t.Fatalf("create temp dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(dir) }) + + sockPath := filepath.Join(dir, "agent.sock") ln, err := net.Listen("unix", sockPath) if err != nil { t.Fatalf("listen: %v", err) From b3036e6b030c2b3c7a18a64466df8b3602f526e4 Mon Sep 17 00:00:00 2001 From: Jonathan Chun Date: Tue, 17 Feb 2026 18:09:14 -0800 Subject: [PATCH 3/4] test: use os.MkdirTemp for shorter socket paths, fix staticcheck nil warnings --- integration_shellguard_test.go | 1 + manifest/loaddir_test.go | 1 + ssh/auth_test.go | 5 +++-- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/integration_shellguard_test.go b/integration_shellguard_test.go index a1f057b..ebbea41 100644 --- a/integration_shellguard_test.go +++ b/integration_shellguard_test.go @@ -138,6 +138,7 @@ func TestIntegrationToolNames(t *testing.T) { m := registry[tool] if m == nil { t.Fatalf("missing manifest for toolkit tool %q", tool) + return } if m.Deny { t.Fatalf("tool %q must be allowed, but manifest is deny=true", tool) diff --git a/manifest/loaddir_test.go b/manifest/loaddir_test.go index 628427a..ee4d0cc 100644 --- a/manifest/loaddir_test.go +++ b/manifest/loaddir_test.go @@ -51,6 +51,7 @@ allows_path_args: true f := m.GetFlag("-v") if f == nil { t.Fatal("GetFlag(\"-v\") = nil") + return } if got, want := f.Description, "verbose output"; got != want { t.Fatalf("flag Description = %q, want %q", got, want) diff --git a/ssh/auth_test.go b/ssh/auth_test.go index 913afb5..7574e9a 100644 --- a/ssh/auth_test.go +++ b/ssh/auth_test.go @@ -499,8 +499,9 @@ func startTestAgentWithKey(t *testing.T) string { // startTestAgentKeyring starts an ssh-agent serving the given keyring. func startTestAgentKeyring(t *testing.T, keyring agent.Agent) string { t.Helper() - // Use /tmp directly to avoid long paths on macOS causing "bind: invalid argument" - dir, err := os.MkdirTemp("/tmp", "sg-test-") + // Use os.MkdirTemp with default temp dir to avoid t.TempDir()'s long paths + // which can exceed the 103-byte limit for unix sockets on macOS. + dir, err := os.MkdirTemp("", "sg-agent-") if err != nil { t.Fatalf("create temp dir: %v", err) } From 7650df07a2c63bfdcba4745fb2d2f2678b4522bc Mon Sep 17 00:00:00 2001 From: Jonathan Chun Date: Sat, 28 Feb 2026 11:02:29 -0800 Subject: [PATCH 4/4] feat(control): add unix control socket API and config wiring --- config/config.go | 43 ++++++++-- control/adapter.go | 33 ++++++++ control/control.go | 184 ++++++++++++++++++++++++++++++++++++++++++ server/server.go | 10 ++- server/server_test.go | 2 +- shellguard.go | 19 +++++ 6 files changed, 280 insertions(+), 11 deletions(-) create mode 100644 control/adapter.go create mode 100644 control/control.go diff --git a/config/config.go b/config/config.go index b4585f7..8d468ce 100644 --- a/config/config.go +++ b/config/config.go @@ -40,15 +40,25 @@ func (d *duration) Duration() time.Duration { return d.d } +// AutoConnectConfig holds auto-connect parameters set via environment variables. +// When Host is non-empty, ShellGuard connects during MCP initialization. +type AutoConnectConfig struct { + Host string + User string + Port int + IdentityFile string +} + // Config for ShellGuard. Pointer fields; nil = unset. type Config struct { - Timeout *int `yaml:"timeout"` - MaxOutputBytes *int `yaml:"max_output_bytes"` - MaxDownloadBytes *int `yaml:"max_download_bytes"` - DownloadDir *string `yaml:"download_dir"` - MaxSleepSeconds *int `yaml:"max_sleep_seconds"` - SSH *SSHConfig `yaml:"ssh"` - ManifestDir *string `yaml:"manifest_dir"` + Timeout *int `yaml:"timeout"` + MaxOutputBytes *int `yaml:"max_output_bytes"` + MaxDownloadBytes *int `yaml:"max_download_bytes"` + DownloadDir *string `yaml:"download_dir"` + MaxSleepSeconds *int `yaml:"max_sleep_seconds"` + SSH *SSHConfig `yaml:"ssh"` + ManifestDir *string `yaml:"manifest_dir"` + AutoConnect *AutoConnectConfig `yaml:"-"` // env-only, not from config file } // SSHConfig holds SSH-specific configuration. @@ -180,6 +190,25 @@ func (c *Config) applyEnvOverrides() error { c.SSH.KnownHostsFile = &v } + // Auto-connect env vars (presence of SHELLGUARD_HOST triggers auto-connect). + if host, ok := os.LookupEnv("SHELLGUARD_HOST"); ok && host != "" { + ac := &AutoConnectConfig{Host: host} + if v, ok := os.LookupEnv("SHELLGUARD_USER"); ok { + ac.User = v + } + if v, ok := os.LookupEnv("SHELLGUARD_PORT"); ok { + n, err := strconv.Atoi(v) + if err != nil { + return fmt.Errorf("parse SHELLGUARD_PORT: %w", err) + } + ac.Port = n + } + if v, ok := os.LookupEnv("SHELLGUARD_IDENTITY_FILE"); ok { + ac.IdentityFile = v + } + c.AutoConnect = ac + } + return nil } diff --git a/control/adapter.go b/control/adapter.go new file mode 100644 index 0000000..750c4ea --- /dev/null +++ b/control/adapter.go @@ -0,0 +1,33 @@ +package control + +import ( + "context" + + "github.com/jonchun/shellguard/server" +) + +// CoreAdapter implements Handler by delegating to a server.Core instance. +type CoreAdapter struct { + Core *server.Core +} + +func (a *CoreAdapter) Connect(ctx context.Context, params ConnectParams) error { + _, err := a.Core.Connect(ctx, server.ConnectInput{ + Host: params.Host, + User: params.User, + Port: params.Port, + IdentityFile: params.IdentityFile, + }) + return err +} + +func (a *CoreAdapter) Disconnect(ctx context.Context, params DisconnectParams) error { + _, err := a.Core.Disconnect(ctx, server.DisconnectInput{ + Host: params.Host, + }) + return err +} + +func (a *CoreAdapter) ConnectedHosts() []string { + return a.Core.ConnectedHostsSnapshot() +} diff --git a/control/control.go b/control/control.go new file mode 100644 index 0000000..f9384dc --- /dev/null +++ b/control/control.go @@ -0,0 +1,184 @@ +// Package control provides a JSON-over-Unix-socket API for managing ShellGuard +// connections without going through the MCP/agent layer. +package control + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "log/slog" + "net" + "os" + "sync" +) + +// Request is the envelope sent by a client over the control socket. +type Request struct { + Command string `json:"command"` + Params json.RawMessage `json:"params,omitempty"` +} + +// Response is the envelope sent back to the client. +type Response struct { + OK bool `json:"ok"` + Data json.RawMessage `json:"data,omitempty"` + Error string `json:"error,omitempty"` +} + +// ConnectParams are the parameters for the "connect" command. +type ConnectParams struct { + Host string `json:"host"` + User string `json:"user,omitempty"` + Port int `json:"port,omitempty"` + IdentityFile string `json:"identity_file,omitempty"` +} + +// DisconnectParams are the parameters for the "disconnect" command. +type DisconnectParams struct { + Host string `json:"host"` +} + +// StatusData is returned by the "status" command. +type StatusData struct { + ConnectedHosts []string `json:"connected_hosts"` +} + +// Handler is the interface that the control socket server dispatches to. +type Handler interface { + Connect(ctx context.Context, params ConnectParams) error + Disconnect(ctx context.Context, params DisconnectParams) error + ConnectedHosts() []string +} + +// Server listens on a Unix socket and dispatches JSON requests to a Handler. +type Server struct { + listener net.Listener + handler Handler + logger *slog.Logger + + wg sync.WaitGroup +} + +// ListenAndServe starts the control socket server. It blocks until ctx is +// cancelled, then cleans up the socket file. +func ListenAndServe(ctx context.Context, socketPath string, handler Handler, logger *slog.Logger) error { + // Remove stale socket for idempotent restarts. + _ = os.Remove(socketPath) + + ln, err := net.Listen("unix", socketPath) + if err != nil { + return err + } + if err := os.Chmod(socketPath, 0600); err != nil { + _ = ln.Close() + _ = os.Remove(socketPath) + return err + } + + s := &Server{ + listener: ln, + handler: handler, + logger: logger, + } + + // Shut down when context is cancelled. + go func() { + <-ctx.Done() + _ = ln.Close() + }() + + logger.Info("control socket listening", "path", socketPath) + + for { + conn, err := ln.Accept() + if err != nil { + // Expected when listener is closed during shutdown. + if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { + break + } + logger.Warn("control socket accept error", "error", err) + continue + } + s.wg.Add(1) + go s.handleConn(ctx, conn) + } + + s.wg.Wait() + _ = os.Remove(socketPath) + logger.Info("control socket stopped") + return nil +} + +func (s *Server) handleConn(ctx context.Context, conn net.Conn) { + defer s.wg.Done() + defer func() { _ = conn.Close() }() + + scanner := bufio.NewScanner(conn) + // Allow up to 1 MB per line. + scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024) + + for scanner.Scan() { + line := scanner.Bytes() + if len(line) == 0 { + continue + } + + var req Request + if err := json.Unmarshal(line, &req); err != nil { + s.writeResponse(conn, Response{Error: "invalid JSON: " + err.Error()}) + continue + } + + resp := s.dispatch(ctx, req) + s.writeResponse(conn, resp) + } +} + +func (s *Server) dispatch(ctx context.Context, req Request) Response { + switch req.Command { + case "connect": + var params ConnectParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return Response{Error: "invalid connect params: " + err.Error()} + } + if err := s.handler.Connect(ctx, params); err != nil { + return Response{Error: err.Error()} + } + data, _ := json.Marshal(map[string]string{ + "host": params.Host, + "message": "Connected to " + params.Host, + }) + return Response{OK: true, Data: data} + + case "disconnect": + var params DisconnectParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + return Response{Error: "invalid disconnect params: " + err.Error()} + } + if err := s.handler.Disconnect(ctx, params); err != nil { + return Response{Error: err.Error()} + } + return Response{OK: true} + + case "status": + hosts := s.handler.ConnectedHosts() + data, _ := json.Marshal(StatusData{ConnectedHosts: hosts}) + return Response{OK: true, Data: data} + + default: + return Response{Error: "unknown command: " + req.Command} + } +} + +func (s *Server) writeResponse(conn net.Conn, resp Response) { + line, err := json.Marshal(resp) + if err != nil { + s.logger.Error("control socket marshal error", "error", err) + return + } + line = append(line, '\n') + if _, err := conn.Write(line); err != nil { + s.logger.Debug("control socket write error", "error", err) + } +} diff --git a/server/server.go b/server/server.go index a7413e7..57fa249 100644 --- a/server/server.go +++ b/server/server.go @@ -515,7 +515,7 @@ func (c *Core) resolveHostForState(host string) string { if host != "" { return host } - hosts := c.connectedHostsSnapshot() + hosts := c.ConnectedHostsSnapshot() if len(hosts) == 1 { return hosts[0] } @@ -529,7 +529,7 @@ func (c *Core) resolveProvisionHost(host string) (string, error) { } return host, nil } - hosts := c.connectedHostsSnapshot() + hosts := c.ConnectedHostsSnapshot() switch len(hosts) { case 0: return "", errors.New("not connected") @@ -540,7 +540,8 @@ func (c *Core) resolveProvisionHost(host string) (string, error) { } } -func (c *Core) connectedHostsSnapshot() []string { +// ConnectedHostsSnapshot returns a sorted snapshot of currently connected hosts. +func (c *Core) ConnectedHostsSnapshot() []string { c.mu.RLock() defer c.mu.RUnlock() hosts := make([]string, 0, len(c.connectedHosts)) @@ -674,6 +675,9 @@ type ServerOptions struct { Name string // Version is the MCP server implementation version. Default: "0.2.0". Version string + // AutoConnect, when non-nil, causes an automatic SSH connection after + // the MCP handshake completes (via InitializedHandler). + AutoConnect *ConnectInput } func NewMCPServer(core *Core, opts ...ServerOptions) *mcp.Server { diff --git a/server/server_test.go b/server/server_test.go index 987b4af..d03a2a9 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1008,7 +1008,7 @@ func TestCloseDisconnectsAll(t *testing.T) { } // Verify internal Core state is cleared. - if hosts := core.connectedHostsSnapshot(); len(hosts) != 0 { + if hosts := core.ConnectedHostsSnapshot(); len(hosts) != 0 { t.Fatalf("expected 0 connectedHosts, got %v", hosts) } for _, host := range []string{"h1", "h2"} { diff --git a/shellguard.go b/shellguard.go index 3df6fa2..3d97ec9 100644 --- a/shellguard.go +++ b/shellguard.go @@ -5,8 +5,10 @@ import ( "context" "fmt" "log/slog" + "os" "github.com/jonchun/shellguard/config" + "github.com/jonchun/shellguard/control" "github.com/jonchun/shellguard/manifest" "github.com/jonchun/shellguard/server" "github.com/jonchun/shellguard/ssh" @@ -109,12 +111,29 @@ func New(cfg Config) (*server.Core, error) { // RunStdio creates a server from cfg and runs it over stdin/stdout. // All SSH connections are closed when RunStdio returns. +// +// If the SHELLGUARD_CONTROL_SOCKET environment variable is set, a control +// socket server is started alongside the MCP server. Failures to start the +// control socket are non-fatal. func RunStdio(ctx context.Context, cfg Config) error { core, err := New(cfg) if err != nil { return err } defer func() { _ = core.Close(ctx) }() + + if socketPath := os.Getenv("SHELLGUARD_CONTROL_SOCKET"); socketPath != "" { + logger := cfg.Logger + if logger == nil { + logger = slog.Default() + } + go func() { + if err := control.ListenAndServe(ctx, socketPath, &control.CoreAdapter{Core: core}, logger); err != nil { + logger.Warn("control socket failed", "error", err) + } + }() + } + return server.RunStdio(ctx, core, server.ServerOptions{ Name: cfg.Name, Version: cfg.Version,