Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions control/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ func (a *CoreAdapter) Connect(ctx context.Context, params ConnectParams) error {
User: params.User,
Port: params.Port,
IdentityFile: params.IdentityFile,
Password: params.Password,
Passphrase: params.Passphrase,
})
return err
}
Expand Down
29 changes: 15 additions & 14 deletions control/control.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Package control provides a JSON-over-Unix-socket API for managing ShellGuard
// Package control provides a JSON-over-TCP API for managing ShellGuard
// connections without going through the MCP/agent layer.
package control

Expand Down Expand Up @@ -32,6 +32,8 @@ type ConnectParams struct {
User string `json:"user,omitempty"`
Port int `json:"port,omitempty"`
IdentityFile string `json:"identity_file,omitempty"`
Password string `json:"password,omitempty"`
Passphrase string `json:"passphrase,omitempty"`
}

// DisconnectParams are the parameters for the "disconnect" command.
Expand Down Expand Up @@ -60,19 +62,18 @@ type Server struct {
wg sync.WaitGroup
}

// ListenAndServe starts the control socket server. It blocks until ctx is
// cancelled, then cleans up the socket file.
func ListenAndServe(ctx context.Context, socketPath string, handler Handler, logger *slog.Logger) error {
// Remove stale socket for idempotent restarts.
_ = os.Remove(socketPath)

ln, err := net.Listen("unix", socketPath)
// ListenAndServe starts the control server on TCP localhost. It writes the
// resolved host:port to addrPath so clients can discover it. It blocks until
// ctx is cancelled, then cleans up the addr file.
func ListenAndServe(ctx context.Context, addrPath string, handler Handler, logger *slog.Logger) error {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return err
}
if err := os.Chmod(socketPath, 0600); err != nil {

resolvedAddr := ln.Addr().String()
if err := os.WriteFile(addrPath, []byte(resolvedAddr), 0600); err != nil {
_ = ln.Close()
_ = os.Remove(socketPath)
return err
}

Expand All @@ -88,7 +89,7 @@ func ListenAndServe(ctx context.Context, socketPath string, handler Handler, log
_ = ln.Close()
}()

logger.Info("control socket listening", "path", socketPath)
logger.Info("control server listening", "addr", resolvedAddr, "addrFile", addrPath)

for {
conn, err := ln.Accept()
Expand All @@ -97,16 +98,16 @@ func ListenAndServe(ctx context.Context, socketPath string, handler Handler, log
if errors.Is(err, net.ErrClosed) || ctx.Err() != nil {
break
}
logger.Warn("control socket accept error", "error", err)
logger.Warn("control server accept error", "error", err)
continue
}
s.wg.Add(1)
go s.handleConn(ctx, conn)
}

s.wg.Wait()
_ = os.Remove(socketPath)
logger.Info("control socket stopped")
_ = os.Remove(addrPath)
logger.Info("control server stopped")
return nil
}

Expand Down
4 changes: 4 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ type ConnectInput struct {
User string `json:"user,omitempty" jsonschema:"SSH username (default root)"`
Port int `json:"port,omitempty" jsonschema:"SSH port (default 22)"`
IdentityFile string `json:"identity_file,omitempty" jsonschema:"Path to SSH identity file"`
Password string `json:"password,omitempty" jsonschema:"SSH password"`
Passphrase string `json:"passphrase,omitempty" jsonschema:"Passphrase for encrypted key"`
}

type ExecuteInput struct {
Expand Down Expand Up @@ -185,6 +187,8 @@ func (c *Core) Connect(ctx context.Context, in ConnectInput) (map[string]any, er
User: in.User,
Port: in.Port,
IdentityFile: in.IdentityFile,
Password: in.Password,
Passphrase: in.Passphrase,
}
if err := c.Runner.Connect(ctx, params); err != nil {
c.logger.InfoContext(ctx, "connect",
Expand Down
13 changes: 7 additions & 6 deletions shellguard.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,25 @@ func New(cfg Config) (*server.Core, error) {
// RunStdio creates a server from cfg and runs it over stdin/stdout.
// All SSH connections are closed when RunStdio returns.
//
// If the SHELLGUARD_CONTROL_SOCKET environment variable is set, a control
// socket server is started alongside the MCP server. Failures to start the
// control socket are non-fatal.
// If the SHELLGUARD_CONTROL_ADDR environment variable is set (path to an
// .addr file), a TCP control server is started alongside the MCP server.
// The resolved host:port is written to the file. Failures to start the
// control server are non-fatal.
func RunStdio(ctx context.Context, cfg Config) error {
core, err := New(cfg)
if err != nil {
return err
}
defer func() { _ = core.Close(ctx) }()

if socketPath := os.Getenv("SHELLGUARD_CONTROL_SOCKET"); socketPath != "" {
if addrPath := os.Getenv("SHELLGUARD_CONTROL_ADDR"); addrPath != "" {
logger := cfg.Logger
if logger == nil {
logger = slog.Default()
}
go func() {
if err := control.ListenAndServe(ctx, socketPath, &control.CoreAdapter{Core: core}, logger); err != nil {
logger.Warn("control socket failed", "error", err)
if err := control.ListenAndServe(ctx, addrPath, &control.CoreAdapter{Core: core}, logger); err != nil {
logger.Warn("control server failed", "error", err)
}
}()
}
Expand Down
40 changes: 31 additions & 9 deletions ssh/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,24 @@ func defaultKeyPaths() []string {
}

// loadPrivateKey attempts to load and parse a private key from the given path.
// Returns nil if the file doesn't exist, can't be read, or can't be parsed
// (including passphrase-protected keys). All failures are silent by design.
func loadPrivateKey(path string) gossh.Signer {
// If the key is passphrase-protected and passphrase is non-empty, it attempts
// decryption with the passphrase. Returns nil if the file doesn't exist, can't
// be read, or can't be parsed. All failures are silent by design.
func loadPrivateKey(path string, passphrase string) gossh.Signer {
key, err := os.ReadFile(path)
if err != nil {
return nil
}
signer, err := gossh.ParsePrivateKey(key)
if err != nil {
var ppErr *gossh.PassphraseMissingError
if errors.As(err, &ppErr) && passphrase != "" {
signer, err = gossh.ParsePrivateKeyWithPassphrase(key, []byte(passphrase))
if err != nil {
return nil
}
return signer
}
return nil
}
return signer
Expand Down Expand Up @@ -124,12 +133,20 @@ func buildAuthMethodsWithDefaults(params ConnectionParams, defaults []string) ([
if err != nil {
var ppErr *gossh.PassphraseMissingError
if errors.As(err, &ppErr) {
return nil, agentCleanup, fmt.Errorf(
"key %s is passphrase-protected; add it to your ssh-agent with: ssh-add %s",
params.IdentityFile, params.IdentityFile,
)
if params.Passphrase != "" {
signer, err = gossh.ParsePrivateKeyWithPassphrase(key, []byte(params.Passphrase))
if err != nil {
return nil, agentCleanup, fmt.Errorf("decrypt identity key with passphrase: %w", err)
}
} else {
return nil, agentCleanup, fmt.Errorf(
"key %s is passphrase-protected; provide a passphrase",
params.IdentityFile,
)
}
} else {
return nil, agentCleanup, fmt.Errorf("parse identity key: %w", err)
}
return nil, agentCleanup, fmt.Errorf("parse identity key: %w", err)
}
methods = append(methods, gossh.PublicKeys(signer))
}
Expand All @@ -151,10 +168,15 @@ func buildAuthMethodsWithDefaults(params ConnectionParams, defaults []string) ([
}
tried[normPath] = struct{}{}

if signer := loadPrivateKey(path); signer != nil {
if signer := loadPrivateKey(path, params.Passphrase); signer != nil {
methods = append(methods, gossh.PublicKeys(signer))
}
}

// Priority 4: Password auth.
if params.Password != "" {
methods = append(methods, gossh.Password(params.Password))
}

return methods, agentCleanup, nil
}
104 changes: 97 additions & 7 deletions ssh/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestDefaultKeyPathsExpectedOrder(t *testing.T) {
// --- loadPrivateKey tests ---

func TestLoadPrivateKeyMissingFile(t *testing.T) {
signer := loadPrivateKey("/nonexistent/path/id_ed25519")
signer := loadPrivateKey("/nonexistent/path/id_ed25519", "")
if signer != nil {
t.Fatal("expected nil signer for missing file")
}
Expand All @@ -168,7 +168,7 @@ func TestLoadPrivateKeyMissingFile(t *testing.T) {
func TestLoadPrivateKeyInvalidContent(t *testing.T) {
dir := t.TempDir()
path := writeTestKey(t, dir, "bad_key", []byte("not a valid key"))
signer := loadPrivateKey(path)
signer := loadPrivateKey(path, "")
if signer != nil {
t.Fatal("expected nil signer for invalid key content")
}
Expand All @@ -177,7 +177,7 @@ func TestLoadPrivateKeyInvalidContent(t *testing.T) {
func TestLoadPrivateKeyValidKey(t *testing.T) {
dir := t.TempDir()
path := writeTestKey(t, dir, "id_ed25519", generateTestKeyPEM(t))
signer := loadPrivateKey(path)
signer := loadPrivateKey(path, "")
if signer == nil {
t.Fatal("expected non-nil signer for valid key")
}
Expand Down Expand Up @@ -409,11 +409,101 @@ func TestBuildAuthMethodsExplicitPassphraseProtected(t *testing.T) {
t.Fatal("expected error for passphrase-protected explicit identity file")
}
errMsg := err.Error()
if !strings.Contains(errMsg, "ssh-add") {
t.Errorf("error = %q, want it to contain 'ssh-add'", errMsg)
if !strings.Contains(errMsg, "passphrase-protected") {
t.Errorf("error = %q, want it to contain 'passphrase-protected'", errMsg)
}
if !strings.Contains(errMsg, "ssh-agent") {
t.Errorf("error = %q, want it to contain 'ssh-agent'", errMsg)
if !strings.Contains(errMsg, "provide a passphrase") {
t.Errorf("error = %q, want it to contain 'provide a passphrase'", errMsg)
}
}

func TestBuildAuthMethodsExplicitPassphraseDecrypts(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "")
dir := t.TempDir()
keyPath := writeTestKey(t, dir, "id_ed25519_enc", generatePassphraseProtectedKeyPEM(t))

methods, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{
Host: "example.com",
IdentityFile: keyPath,
Passphrase: "test-passphrase",
}, nil)
defer cleanup()
if err != nil {
t.Fatalf("buildAuthMethodsWithDefaults() error = %v", err)
}
if len(methods) == 0 {
t.Fatal("expected at least one auth method after passphrase decryption")
}
}

func TestBuildAuthMethodsExplicitPassphraseWrong(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "")
dir := t.TempDir()
keyPath := writeTestKey(t, dir, "id_ed25519_enc", generatePassphraseProtectedKeyPEM(t))

_, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{
Host: "example.com",
IdentityFile: keyPath,
Passphrase: "wrong-passphrase",
}, nil)
defer cleanup()
if err == nil {
t.Fatal("expected error for wrong passphrase")
}
if !strings.Contains(err.Error(), "decrypt identity key") {
t.Errorf("error = %q, want it to contain 'decrypt identity key'", err.Error())
}
}

func TestBuildAuthMethodsPasswordAuth(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "")
methods, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{
Host: "example.com",
Password: "secret",
}, nil)
defer cleanup()
if err != nil {
t.Fatalf("buildAuthMethodsWithDefaults() error = %v", err)
}
if len(methods) != 1 {
t.Errorf("expected 1 auth method (password), got %d", len(methods))
}
}

func TestBuildAuthMethodsKeyPlusPassword(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "")
dir := t.TempDir()
keyPath := writeTestKey(t, dir, "explicit_key", generateTestKeyPEM(t))

methods, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{
Host: "example.com",
IdentityFile: keyPath,
Password: "secret",
}, nil)
defer cleanup()
if err != nil {
t.Fatalf("buildAuthMethodsWithDefaults() error = %v", err)
}
if len(methods) != 2 {
t.Errorf("expected 2 auth methods (key + password), got %d", len(methods))
}
}

func TestBuildAuthMethodsDefaultPassphraseDecrypts(t *testing.T) {
t.Setenv("SSH_AUTH_SOCK", "")
dir := t.TempDir()
encPath := writeTestKey(t, dir, "id_ed25519", generatePassphraseProtectedKeyPEM(t))

methods, cleanup, err := buildAuthMethodsWithDefaults(ConnectionParams{
Host: "example.com",
Passphrase: "test-passphrase",
}, []string{encPath})
defer cleanup()
if err != nil {
t.Fatalf("buildAuthMethodsWithDefaults() error = %v, want nil", err)
}
if len(methods) != 1 {
t.Errorf("expected 1 auth method for passphrase-decrypted default, got %d", len(methods))
}
}

Expand Down
2 changes: 2 additions & 0 deletions ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ type ConnectionParams struct {
User string
Port int
IdentityFile string
Password string // SSH password auth
Passphrase string // decrypt passphrase-protected key
}

type ManagedConnection struct {
Expand Down