diff --git a/control/adapter.go b/control/adapter.go index 750c4ea..1126507 100644 --- a/control/adapter.go +++ b/control/adapter.go @@ -17,6 +17,8 @@ func (a *CoreAdapter) Connect(ctx context.Context, params ConnectParams) error { User: params.User, Port: params.Port, IdentityFile: params.IdentityFile, + Password: params.Password, + Passphrase: params.Passphrase, }) return err } diff --git a/control/control.go b/control/control.go index f9384dc..7224ff7 100644 --- a/control/control.go +++ b/control/control.go @@ -1,4 +1,4 @@ -// Package control provides a JSON-over-Unix-socket API for managing ShellGuard +// Package control provides a JSON-over-TCP API for managing ShellGuard // connections without going through the MCP/agent layer. package control @@ -32,6 +32,8 @@ type ConnectParams struct { User string `json:"user,omitempty"` Port int `json:"port,omitempty"` IdentityFile string `json:"identity_file,omitempty"` + Password string `json:"password,omitempty"` + Passphrase string `json:"passphrase,omitempty"` } // DisconnectParams are the parameters for the "disconnect" command. @@ -60,19 +62,18 @@ type Server struct { wg sync.WaitGroup } -// ListenAndServe starts the control socket server. It blocks until ctx is -// cancelled, then cleans up the socket file. -func ListenAndServe(ctx context.Context, socketPath string, handler Handler, logger *slog.Logger) error { - // Remove stale socket for idempotent restarts. - _ = os.Remove(socketPath) - - ln, err := net.Listen("unix", socketPath) +// ListenAndServe starts the control server on TCP localhost. It writes the +// resolved host:port to addrPath so clients can discover it. It blocks until +// ctx is cancelled, then cleans up the addr file. +func ListenAndServe(ctx context.Context, addrPath string, handler Handler, logger *slog.Logger) error { + ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return err } - if err := os.Chmod(socketPath, 0600); err != nil { + + resolvedAddr := ln.Addr().String() + if err := os.WriteFile(addrPath, []byte(resolvedAddr), 0600); err != nil { _ = ln.Close() - _ = os.Remove(socketPath) return err } @@ -88,7 +89,7 @@ func ListenAndServe(ctx context.Context, socketPath string, handler Handler, log _ = ln.Close() }() - logger.Info("control socket listening", "path", socketPath) + logger.Info("control server listening", "addr", resolvedAddr, "addrFile", addrPath) for { conn, err := ln.Accept() @@ -97,7 +98,7 @@ func ListenAndServe(ctx context.Context, socketPath string, handler Handler, log if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { break } - logger.Warn("control socket accept error", "error", err) + logger.Warn("control server accept error", "error", err) continue } s.wg.Add(1) @@ -105,8 +106,8 @@ func ListenAndServe(ctx context.Context, socketPath string, handler Handler, log } s.wg.Wait() - _ = os.Remove(socketPath) - logger.Info("control socket stopped") + _ = os.Remove(addrPath) + logger.Info("control server stopped") return nil } diff --git a/server/server.go b/server/server.go index 57fa249..65a08bb 100644 --- a/server/server.go +++ b/server/server.go @@ -74,6 +74,8 @@ type ConnectInput struct { User string `json:"user,omitempty" jsonschema:"SSH username (default root)"` Port int `json:"port,omitempty" jsonschema:"SSH port (default 22)"` IdentityFile string `json:"identity_file,omitempty" jsonschema:"Path to SSH identity file"` + Password string `json:"password,omitempty" jsonschema:"SSH password"` + Passphrase string `json:"passphrase,omitempty" jsonschema:"Passphrase for encrypted key"` } type ExecuteInput struct { @@ -185,6 +187,8 @@ func (c *Core) Connect(ctx context.Context, in ConnectInput) (map[string]any, er User: in.User, Port: in.Port, IdentityFile: in.IdentityFile, + Password: in.Password, + Passphrase: in.Passphrase, } if err := c.Runner.Connect(ctx, params); err != nil { c.logger.InfoContext(ctx, "connect", diff --git a/shellguard.go b/shellguard.go index 3d97ec9..d2415ef 100644 --- a/shellguard.go +++ b/shellguard.go @@ -112,9 +112,10 @@ func New(cfg Config) (*server.Core, error) { // RunStdio creates a server from cfg and runs it over stdin/stdout. // All SSH connections are closed when RunStdio returns. // -// If the SHELLGUARD_CONTROL_SOCKET environment variable is set, a control -// socket server is started alongside the MCP server. Failures to start the -// control socket are non-fatal. +// If the SHELLGUARD_CONTROL_ADDR environment variable is set (path to an +// .addr file), a TCP control server is started alongside the MCP server. +// The resolved host:port is written to the file. Failures to start the +// control server are non-fatal. func RunStdio(ctx context.Context, cfg Config) error { core, err := New(cfg) if err != nil { @@ -122,14 +123,14 @@ func RunStdio(ctx context.Context, cfg Config) error { } defer func() { _ = core.Close(ctx) }() - if socketPath := os.Getenv("SHELLGUARD_CONTROL_SOCKET"); socketPath != "" { + if addrPath := os.Getenv("SHELLGUARD_CONTROL_ADDR"); addrPath != "" { logger := cfg.Logger if logger == nil { logger = slog.Default() } go func() { - if err := control.ListenAndServe(ctx, socketPath, &control.CoreAdapter{Core: core}, logger); err != nil { - logger.Warn("control socket failed", "error", err) + if err := control.ListenAndServe(ctx, addrPath, &control.CoreAdapter{Core: core}, logger); err != nil { + logger.Warn("control server failed", "error", err) } }() } diff --git a/ssh/auth.go b/ssh/auth.go index e4e1ea1..2397e11 100644 --- a/ssh/auth.go +++ b/ssh/auth.go @@ -29,15 +29,24 @@ func defaultKeyPaths() []string { } // loadPrivateKey attempts to load and parse a private key from the given path. -// Returns nil if the file doesn't exist, can't be read, or can't be parsed -// (including passphrase-protected keys). All failures are silent by design. -func loadPrivateKey(path string) gossh.Signer { +// If the key is passphrase-protected and passphrase is non-empty, it attempts +// decryption with the passphrase. Returns nil if the file doesn't exist, can't +// be read, or can't be parsed. All failures are silent by design. +func loadPrivateKey(path string, passphrase string) gossh.Signer { key, err := os.ReadFile(path) if err != nil { return nil } signer, err := gossh.ParsePrivateKey(key) if err != nil { + var ppErr *gossh.PassphraseMissingError + if errors.As(err, &ppErr) && passphrase != "" { + signer, err = gossh.ParsePrivateKeyWithPassphrase(key, []byte(passphrase)) + if err != nil { + return nil + } + return signer + } return nil } return signer @@ -124,12 +133,20 @@ func buildAuthMethodsWithDefaults(params ConnectionParams, defaults []string) ([ if err != nil { var ppErr *gossh.PassphraseMissingError if errors.As(err, &ppErr) { - return nil, agentCleanup, fmt.Errorf( - "key %s is passphrase-protected; add it to your ssh-agent with: ssh-add %s", - params.IdentityFile, params.IdentityFile, - ) + if params.Passphrase != "" { + signer, err = gossh.ParsePrivateKeyWithPassphrase(key, []byte(params.Passphrase)) + if err != nil { + return nil, agentCleanup, fmt.Errorf("decrypt identity key with passphrase: %w", err) + } + } else { + return nil, agentCleanup, fmt.Errorf( + "key %s is passphrase-protected; provide a passphrase", + params.IdentityFile, + ) + } + } else { + return nil, agentCleanup, fmt.Errorf("parse identity key: %w", err) } - return nil, agentCleanup, fmt.Errorf("parse identity key: %w", err) } methods = append(methods, gossh.PublicKeys(signer)) } @@ -151,10 +168,15 @@ func buildAuthMethodsWithDefaults(params ConnectionParams, defaults []string) ([ } tried[normPath] = struct{}{} - if signer := loadPrivateKey(path); signer != nil { + if signer := loadPrivateKey(path, params.Passphrase); signer != nil { methods = append(methods, gossh.PublicKeys(signer)) } } + // Priority 4: Password auth. + if params.Password != "" { + methods = append(methods, gossh.Password(params.Password)) + } + return methods, agentCleanup, nil } diff --git a/ssh/auth_test.go b/ssh/auth_test.go index 7574e9a..747a8db 100644 --- a/ssh/auth_test.go +++ b/ssh/auth_test.go @@ -159,7 +159,7 @@ func TestDefaultKeyPathsExpectedOrder(t *testing.T) { // --- loadPrivateKey tests --- func TestLoadPrivateKeyMissingFile(t *testing.T) { - signer := loadPrivateKey("/nonexistent/path/id_ed25519") + signer := loadPrivateKey("/nonexistent/path/id_ed25519", "") if signer != nil { t.Fatal("expected nil signer for missing file") } @@ -168,7 +168,7 @@ func TestLoadPrivateKeyMissingFile(t *testing.T) { func TestLoadPrivateKeyInvalidContent(t *testing.T) { dir := t.TempDir() path := writeTestKey(t, dir, "bad_key", []byte("not a valid key")) - signer := loadPrivateKey(path) + signer := loadPrivateKey(path, "") if signer != nil { t.Fatal("expected nil signer for invalid key content") } @@ -177,7 +177,7 @@ func TestLoadPrivateKeyInvalidContent(t *testing.T) { func TestLoadPrivateKeyValidKey(t *testing.T) { dir := t.TempDir() path := writeTestKey(t, dir, "id_ed25519", generateTestKeyPEM(t)) - signer := loadPrivateKey(path) + signer := loadPrivateKey(path, "") if signer == nil { t.Fatal("expected non-nil signer for valid key") } @@ -409,11 +409,101 @@ func TestBuildAuthMethodsExplicitPassphraseProtected(t *testing.T) { t.Fatal("expected error for passphrase-protected explicit identity file") } errMsg := err.Error() - if !strings.Contains(errMsg, "ssh-add") { - t.Errorf("error = %q, want it to contain 'ssh-add'", errMsg) + if !strings.Contains(errMsg, "passphrase-protected") { + t.Errorf("error = %q, want it to contain 'passphrase-protected'", errMsg) } - if !strings.Contains(errMsg, "ssh-agent") { - t.Errorf("error = %q, want it to contain 'ssh-agent'", errMsg) + if !strings.Contains(errMsg, "provide a passphrase") { + t.Errorf("error = %q, want it to contain 'provide a passphrase'", errMsg) + } +} + +func TestBuildAuthMethodsExplicitPassphraseDecrypts(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "") + dir := t.TempDir() + keyPath := writeTestKey(t, dir, "id_ed25519_enc", generatePassphraseProtectedKeyPEM(t)) + + methods, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{ + Host: "example.com", + IdentityFile: keyPath, + Passphrase: "test-passphrase", + }, nil) + defer cleanup() + if err != nil { + t.Fatalf("buildAuthMethodsWithDefaults() error = %v", err) + } + if len(methods) == 0 { + t.Fatal("expected at least one auth method after passphrase decryption") + } +} + +func TestBuildAuthMethodsExplicitPassphraseWrong(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "") + dir := t.TempDir() + keyPath := writeTestKey(t, dir, "id_ed25519_enc", generatePassphraseProtectedKeyPEM(t)) + + _, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{ + Host: "example.com", + IdentityFile: keyPath, + Passphrase: "wrong-passphrase", + }, nil) + defer cleanup() + if err == nil { + t.Fatal("expected error for wrong passphrase") + } + if !strings.Contains(err.Error(), "decrypt identity key") { + t.Errorf("error = %q, want it to contain 'decrypt identity key'", err.Error()) + } +} + +func TestBuildAuthMethodsPasswordAuth(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "") + methods, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{ + Host: "example.com", + Password: "secret", + }, nil) + defer cleanup() + if err != nil { + t.Fatalf("buildAuthMethodsWithDefaults() error = %v", err) + } + if len(methods) != 1 { + t.Errorf("expected 1 auth method (password), got %d", len(methods)) + } +} + +func TestBuildAuthMethodsKeyPlusPassword(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "") + dir := t.TempDir() + keyPath := writeTestKey(t, dir, "explicit_key", generateTestKeyPEM(t)) + + methods, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{ + Host: "example.com", + IdentityFile: keyPath, + Password: "secret", + }, nil) + defer cleanup() + if err != nil { + t.Fatalf("buildAuthMethodsWithDefaults() error = %v", err) + } + if len(methods) != 2 { + t.Errorf("expected 2 auth methods (key + password), got %d", len(methods)) + } +} + +func TestBuildAuthMethodsDefaultPassphraseDecrypts(t *testing.T) { + t.Setenv("SSH_AUTH_SOCK", "") + dir := t.TempDir() + encPath := writeTestKey(t, dir, "id_ed25519", generatePassphraseProtectedKeyPEM(t)) + + methods, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{ + Host: "example.com", + Passphrase: "test-passphrase", + }, []string{encPath}) + defer cleanup() + if err != nil { + t.Fatalf("buildAuthMethodsWithDefaults() error = %v, want nil", err) + } + if len(methods) != 1 { + t.Errorf("expected 1 auth method for passphrase-decrypted default, got %d", len(methods)) } } diff --git a/ssh/ssh.go b/ssh/ssh.go index 5b703ad..05cca2b 100644 --- a/ssh/ssh.go +++ b/ssh/ssh.go @@ -48,6 +48,8 @@ type ConnectionParams struct { User string Port int IdentityFile string + Password string // SSH password auth + Passphrase string // decrypt passphrase-protected key } type ManagedConnection struct {