diff --git a/config/config.go b/config/config.go index dc862c5..f9b848f 100644 --- a/config/config.go +++ b/config/config.go @@ -48,6 +48,7 @@ type AutoConnectConfig struct { User string Port int IdentityFile string + Transport string } // Config for ShellGuard. Pointer fields; nil = unset. @@ -218,6 +219,9 @@ func (c *Config) applyEnvOverrides() error { if v, ok := os.LookupEnv("SHELLGUARD_IDENTITY_FILE"); ok { ac.IdentityFile = v } + if v, ok := os.LookupEnv("SHELLGUARD_TRANSPORT"); ok { + ac.Transport = v + } c.AutoConnect = ac } diff --git a/server/local.go b/server/local.go new file mode 100644 index 0000000..947ecd4 --- /dev/null +++ b/server/local.go @@ -0,0 +1,78 @@ +package server + +import ( + "bytes" + "context" + "errors" + "os/exec" + "time" + + "github.com/fawdyinc/shellguard/ssh" +) + +// LocalExecutor runs commands on the local machine. +type LocalExecutor struct{} + +// NewLocalExecutor returns a new LocalExecutor. +func NewLocalExecutor() *LocalExecutor { + return &LocalExecutor{} +} + +func (l *LocalExecutor) Connect(_ context.Context, _ ssh.ConnectionParams) error { + return nil +} + +func (l *LocalExecutor) Execute(ctx context.Context, _, command string, timeout time.Duration) (ssh.ExecResult, error) { + return l.run(ctx, command, timeout) +} + +func (l *LocalExecutor) ExecuteRaw(ctx context.Context, _, command string, timeout time.Duration) (ssh.ExecResult, error) { + return l.run(ctx, command, timeout) +} + +func (l *LocalExecutor) SFTPSession(_ string) (ssh.SFTPClient, error) { + return nil, errors.New("SFTP is not supported for local transport") +} + +func (l *LocalExecutor) Disconnect(_ context.Context, _ string) error { + return nil +} + +func (l *LocalExecutor) run(ctx context.Context, command string, timeout time.Duration) (ssh.ExecResult, error) { + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } + + start := time.Now() + cmd := exec.CommandContext(ctx, "sh", "-c", command) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + runtimeMs := int(time.Since(start).Milliseconds()) + + if ctx.Err() != nil { + return ssh.ExecResult{}, ctx.Err() + } + + exitCode := 0 + if err != nil { + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + exitCode = exitErr.ExitCode() + } else { + return ssh.ExecResult{}, err + } + } + + return ssh.ExecResult{ + Stdout: stdout.String(), + Stderr: stderr.String(), + ExitCode: exitCode, + RuntimeMs: runtimeMs, + }, nil +} diff --git a/server/local_test.go b/server/local_test.go new file mode 100644 index 0000000..bee34da --- /dev/null +++ b/server/local_test.go @@ -0,0 +1,60 @@ +package server + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestLocalExecutor_Execute_Success(t *testing.T) { + exec := NewLocalExecutor() + res, err := exec.Execute(context.Background(), "", "echo hello", 5*time.Second) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if got := strings.TrimSpace(res.Stdout); got != "hello" { + t.Fatalf("stdout = %q, want %q", got, "hello") + } + if res.ExitCode != 0 { + t.Fatalf("exit code = %d, want 0", res.ExitCode) + } +} + +func TestLocalExecutor_Execute_NonZeroExit(t *testing.T) { + exec := NewLocalExecutor() + res, err := exec.Execute(context.Background(), "", "false", 5*time.Second) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if res.ExitCode == 0 { + t.Fatal("expected non-zero exit code") + } +} + +func TestLocalExecutor_Execute_Stderr(t *testing.T) { + exec := NewLocalExecutor() + res, err := exec.Execute(context.Background(), "", "echo err >&2", 5*time.Second) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + if got := strings.TrimSpace(res.Stderr); got != "err" { + t.Fatalf("stderr = %q, want %q", got, "err") + } +} + +func TestLocalExecutor_Execute_Timeout(t *testing.T) { + exec := NewLocalExecutor() + _, err := exec.Execute(context.Background(), "", "sleep 10", 100*time.Millisecond) + if err == nil { + t.Fatal("expected timeout error") + } +} + +func TestLocalExecutor_SFTPSession_Error(t *testing.T) { + exec := NewLocalExecutor() + _, err := exec.SFTPSession("any") + if err == nil { + t.Fatal("expected error from SFTPSession") + } +} diff --git a/server/server.go b/server/server.go index fecd0db..4b28efd 100644 --- a/server/server.go +++ b/server/server.go @@ -42,14 +42,49 @@ type Executor interface { Disconnect(ctx context.Context, host string) error } +// TransportType identifies how a server connection is established. +type TransportType string + +const ( + TransportSSH TransportType = "ssh" + TransportLocal TransportType = "local" +) + +// ServerEntry tracks the state of a connected server. +type ServerEntry struct { + Transport TransportType + Connected bool +} + +type ValidateInput struct { + Command string `json:"command" jsonschema:"Shell command or pipeline to validate"` +} + +type ValidateResult struct { + OK bool `json:"ok"` + Reason string `json:"reason,omitempty"` + Command string `json:"command,omitempty"` + Detail string `json:"detail,omitempty"` +} + +type StatusInput struct{} + +type ServerStatus struct { + Connected bool `json:"connected"` + Transport TransportType `json:"transport"` +} + +type StatusResult map[string]ServerStatus + type ProbeResult struct { Missing []string Arch string } type Core struct { - Registry map[string]*manifest.Manifest - Runner Executor + Registry map[string]*manifest.Manifest + Runner Executor + LocalRunner Executor Parse func(string) (*parser.Pipeline, error) Validate func(*parser.Pipeline, map[string]*manifest.Manifest) error @@ -67,7 +102,7 @@ type Core struct { mu sync.RWMutex probeState map[string]*ProbeResult toolkitDeployed map[string]bool - connectedHosts map[string]struct{} + servers map[string]*ServerEntry } type ConnectInput struct { @@ -77,6 +112,7 @@ type ConnectInput struct { IdentityFile string `json:"identity_file,omitempty" jsonschema:"Path to SSH identity file"` Password string `json:"password,omitempty" jsonschema:"SSH password"` Passphrase string `json:"passphrase,omitempty" jsonschema:"Passphrase for encrypted key"` + Transport string `json:"transport,omitempty" jsonschema:"Transport type: ssh (default) or local"` } type ExecuteInput struct { @@ -146,6 +182,7 @@ func NewCore(registry map[string]*manifest.Manifest, runner Executor, logger *sl c := &Core{ Registry: registry, Runner: runner, + LocalRunner: NewLocalExecutor(), logger: logger, Parse: parser.Parse, Validate: validator.ValidatePipeline, @@ -158,7 +195,7 @@ func NewCore(registry map[string]*manifest.Manifest, runner Executor, logger *sl MaxSleepSeconds: 15, probeState: make(map[string]*ProbeResult), toolkitDeployed: make(map[string]bool), - connectedHosts: make(map[string]struct{}), + servers: make(map[string]*ServerEntry), } for _, opt := range opts { opt(c) @@ -171,6 +208,7 @@ func NewCore(registry map[string]*manifest.Manifest, runner Executor, logger *sl // It is safe to call multiple times. func (c *Core) Close(ctx context.Context) error { err := c.Runner.Disconnect(ctx, "") + _ = c.LocalRunner.Disconnect(ctx, "") c.clearHostState("") if err != nil { c.logger.InfoContext(ctx, "close", "outcome", "error", "error", err.Error()) @@ -192,6 +230,17 @@ func (c *Core) Connect(ctx context.Context, in ConnectInput) (map[string]any, er start := time.Now() + if strings.EqualFold(in.Transport, "local") { + c.setConnected(in.Host, TransportLocal, true) + c.logger.InfoContext(ctx, "connect", + "host", in.Host, + "transport", "local", + "outcome", "success", + "duration_ms", time.Since(start).Milliseconds(), + ) + return map[string]any{"ok": true, "host": in.Host, "message": fmt.Sprintf("Connected to %s (local)", in.Host)}, nil + } + params := ssh.ConnectionParams{ Host: in.Host, User: in.User, @@ -209,7 +258,7 @@ func (c *Core) Connect(ctx context.Context, in ConnectInput) (map[string]any, er ) return nil, err } - c.setConnected(in.Host, true) + c.setConnected(in.Host, TransportSSH, true) c.setToolkitDeployed(in.Host, false) c.clearProbeState(in.Host) @@ -274,7 +323,7 @@ func (c *Core) Execute(ctx context.Context, in ExecuteInput) (output.CommandResu reconstructed := c.Reconstruct(pipeline, isPSQL, c.isToolkitDeployed(hostForState)) timeout := c.getPipelineTimeout(pipeline) - execRes, err := c.Runner.Execute(ctx, in.Host, reconstructed, timeout) + execRes, err := c.resolveRunner(hostForState).Execute(ctx, in.Host, reconstructed, timeout) if err != nil { c.logger.InfoContext(ctx, "execute", "command", in.Command, @@ -374,7 +423,7 @@ func (c *Core) Provision(ctx context.Context, in ProvisionInput) (map[string]any } func (c *Core) Disconnect(ctx context.Context, in DisconnectInput) (map[string]any, error) { - if err := c.Runner.Disconnect(ctx, in.Host); err != nil { + if err := c.resolveRunner(in.Host).Disconnect(ctx, in.Host); err != nil { c.logger.InfoContext(ctx, "disconnect", "host", in.Host, "outcome", "error", @@ -558,29 +607,61 @@ func (c *Core) resolveProvisionHost(host string) (string, error) { func (c *Core) ConnectedHostsSnapshot() []string { c.mu.RLock() defer c.mu.RUnlock() - hosts := make([]string, 0, len(c.connectedHosts)) - for host := range c.connectedHosts { - hosts = append(hosts, host) + hosts := make([]string, 0, len(c.servers)) + for host, entry := range c.servers { + if entry.Connected { + hosts = append(hosts, host) + } } sort.Strings(hosts) return hosts } +// ServersSnapshot returns a snapshot of all server entries. +func (c *Core) ServersSnapshot() StatusResult { + c.mu.RLock() + defer c.mu.RUnlock() + result := make(StatusResult, len(c.servers)) + for host, entry := range c.servers { + result[host] = ServerStatus{ + Connected: entry.Connected, + Transport: entry.Transport, + } + } + return result +} + func (c *Core) isConnected(host string) bool { c.mu.RLock() defer c.mu.RUnlock() - _, ok := c.connectedHosts[host] - return ok + entry, ok := c.servers[host] + return ok && entry.Connected } -func (c *Core) setConnected(host string, connected bool) { +func (c *Core) setConnected(host string, transport TransportType, connected bool) { c.mu.Lock() defer c.mu.Unlock() if connected { - c.connectedHosts[host] = struct{}{} + c.servers[host] = &ServerEntry{Transport: transport, Connected: true} return } - delete(c.connectedHosts, host) + delete(c.servers, host) +} + +func (c *Core) getTransport(host string) TransportType { + c.mu.RLock() + defer c.mu.RUnlock() + if entry, ok := c.servers[host]; ok { + return entry.Transport + } + return TransportSSH +} + +func (c *Core) resolveRunner(host string) Executor { + if c.getTransport(host) == TransportLocal { + return c.LocalRunner + } + return c.Runner } func (c *Core) setProbeState(host string, result *ProbeResult) { @@ -635,12 +716,12 @@ func (c *Core) clearHostState(host string) { c.mu.Lock() defer c.mu.Unlock() if host == "" { - clear(c.connectedHosts) + clear(c.servers) clear(c.probeState) clear(c.toolkitDeployed) return } - delete(c.connectedHosts, host) + delete(c.servers, host) delete(c.probeState, host) delete(c.toolkitDeployed, host) } @@ -668,6 +749,31 @@ func collisionSafePath(dir, filename string) (string, error) { return "", fmt.Errorf("filename collision: exhausted %d candidates for %q", maxCollisionRetries, filename) } +func (c *Core) ValidateCommand(_ context.Context, in ValidateInput) (ValidateResult, error) { + if strings.TrimSpace(in.Command) == "" { + return ValidateResult{}, errors.New("command is required") + } + + pipeline, err := c.Parse(in.Command) + if err != nil { + return ValidateResult{OK: false, Reason: err.Error()}, nil + } + + if err := c.Validate(pipeline, c.Registry); err != nil { + var ve *validator.ValidationError + if errors.As(err, &ve) { + return ValidateResult{OK: false, Reason: ve.Message}, nil + } + return ValidateResult{OK: false, Reason: err.Error()}, nil + } + + return ValidateResult{OK: true}, nil +} + +func (c *Core) Status(_ context.Context, _ StatusInput) (StatusResult, error) { + return c.ServersSnapshot(), nil +} + func (c *Core) Sleep(ctx context.Context, in SleepInput) (map[string]any, error) { if in.Seconds <= 0 { return nil, errors.New("seconds must be greater than 0") @@ -735,6 +841,24 @@ func NewMCPServer(core *Core, opts ...ServerOptions) *mcp.Server { return nil, out, err }) + mcp.AddTool(srv, &mcp.Tool{ + Name: "validate", + Description: "Validate a shell command against the security policy without executing it.", + Annotations: &mcp.ToolAnnotations{ReadOnlyHint: true, IdempotentHint: true}, + }, func(ctx context.Context, _ *mcp.CallToolRequest, in ValidateInput) (*mcp.CallToolResult, ValidateResult, error) { + out, err := core.ValidateCommand(ctx, in) + return nil, out, err + }) + + mcp.AddTool(srv, &mcp.Tool{ + Name: "status", + Description: "Show connection status for all servers.", + Annotations: &mcp.ToolAnnotations{ReadOnlyHint: true, IdempotentHint: true}, + }, func(ctx context.Context, _ *mcp.CallToolRequest, in StatusInput) (*mcp.CallToolResult, StatusResult, error) { + out, err := core.Status(ctx, in) + return nil, out, err + }) + if !core.DisabledTools["sleep"] { mcp.AddTool(srv, &mcp.Tool{Name: "sleep", Description: fmt.Sprintf("Sleep locally for a specified duration (max %d seconds). Use to wait between checks, e.g. after observing an issue and before re-checking.", core.MaxSleepSeconds)}, func(ctx context.Context, _ *mcp.CallToolRequest, in SleepInput) (*mcp.CallToolResult, map[string]any, error) { diff --git a/server/server_test.go b/server/server_test.go index 2d98f7f..b071561 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1103,3 +1103,198 @@ func TestDownloadFileLogsSuccess(t *testing.T) { } } } + +func TestValidateCommand_OK(t *testing.T) { + core := NewCore(basicRegistry(), newFakeRunner(), nil) + res, err := core.ValidateCommand(context.Background(), ValidateInput{Command: "ls"}) + if err != nil { + t.Fatalf("ValidateCommand() error = %v", err) + } + if !res.OK { + t.Fatalf("expected OK=true, got reason=%q", res.Reason) + } +} + +func TestValidateCommand_ParseError(t *testing.T) { + core := NewCore(basicRegistry(), newFakeRunner(), nil) + core.Parse = func(_ string) (*parser.Pipeline, error) { + return nil, &parser.ParseError{Message: "syntax error"} + } + res, err := core.ValidateCommand(context.Background(), ValidateInput{Command: "$(evil)"}) + if err != nil { + t.Fatalf("ValidateCommand() error = %v", err) + } + if res.OK { + t.Fatal("expected OK=false for parse error") + } + if res.Reason == "" { + t.Fatal("expected non-empty reason") + } +} + +func TestValidateCommand_ValidationError(t *testing.T) { + core := NewCore(basicRegistry(), newFakeRunner(), nil) + res, err := core.ValidateCommand(context.Background(), ValidateInput{Command: "rm -rf /"}) + if err != nil { + t.Fatalf("ValidateCommand() error = %v", err) + } + if res.OK { + t.Fatal("expected OK=false for denied command") + } + if res.Reason == "" { + t.Fatal("expected non-empty reason") + } +} + +func TestValidateCommand_Empty(t *testing.T) { + core := NewCore(basicRegistry(), newFakeRunner(), nil) + _, err := core.ValidateCommand(context.Background(), ValidateInput{Command: ""}) + if err == nil { + t.Fatal("expected error for empty command") + } +} + +func TestStatus_Empty(t *testing.T) { + core := NewCore(basicRegistry(), newFakeRunner(), nil) + res, err := core.Status(context.Background(), StatusInput{}) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + if len(res) != 0 { + t.Fatalf("expected empty status, got %v", res) + } +} + +func TestStatus_WithConnections(t *testing.T) { + runner := newFakeRunner() + core := NewCore(basicRegistry(), runner, nil) + + // Connect SSH host + if _, err := core.Connect(context.Background(), ConnectInput{Host: "ssh-host"}); err != nil { + t.Fatalf("Connect(ssh) error = %v", err) + } + // Connect local host + if _, err := core.Connect(context.Background(), ConnectInput{Host: "local-host", Transport: "local"}); err != nil { + t.Fatalf("Connect(local) error = %v", err) + } + + res, err := core.Status(context.Background(), StatusInput{}) + if err != nil { + t.Fatalf("Status() error = %v", err) + } + if len(res) != 2 { + t.Fatalf("expected 2 entries, got %d", len(res)) + } + if s, ok := res["ssh-host"]; !ok || !s.Connected || s.Transport != TransportSSH { + t.Fatalf("unexpected ssh-host status: %+v", res["ssh-host"]) + } + if s, ok := res["local-host"]; !ok || !s.Connected || s.Transport != TransportLocal { + t.Fatalf("unexpected local-host status: %+v", res["local-host"]) + } +} + +func TestConnect_Local(t *testing.T) { + runner := newFakeRunner() + core := NewCore(basicRegistry(), runner, nil) + + out, err := core.Connect(context.Background(), ConnectInput{Host: "my-local", Transport: "local"}) + if err != nil { + t.Fatalf("Connect(local) error = %v", err) + } + if out["ok"] != true { + t.Fatalf("expected ok=true, got %v", out) + } + + // SSH runner should not have been called + runner.mu.Lock() + called := runner.connectCalled + runner.mu.Unlock() + if called { + t.Fatal("SSH runner.Connect should not be called for local transport") + } + + // Host should appear in connected hosts + if !core.isConnected("my-local") { + t.Fatal("expected my-local to be connected") + } + if core.getTransport("my-local") != TransportLocal { + t.Fatalf("expected local transport, got %s", core.getTransport("my-local")) + } +} + +func TestExecute_Local(t *testing.T) { + runner := newFakeRunner() + core := NewCore(basicRegistry(), runner, nil) + + // Connect local + if _, err := core.Connect(context.Background(), ConnectInput{Host: "local-box", Transport: "local"}); err != nil { + t.Fatalf("Connect(local) error = %v", err) + } + + // Override parse/validate/reconstruct so we can test execution + core.Parse = func(_ string) (*parser.Pipeline, error) { + return &parser.Pipeline{Segments: []parser.PipelineSegment{{Command: "ls"}}}, nil + } + core.Validate = func(_ *parser.Pipeline, _ map[string]*manifest.Manifest) error { return nil } + core.Reconstruct = func(_ *parser.Pipeline, _, _ bool) string { return "echo local-test" } + + res, err := core.Execute(context.Background(), ExecuteInput{Host: "local-box", Command: "ls"}) + if err != nil { + t.Fatalf("Execute() error = %v", err) + } + + // SSH runner should NOT have received the execute call + runner.mu.Lock() + sshExecuted := runner.executeCalled + runner.mu.Unlock() + if sshExecuted { + t.Fatal("SSH runner should not receive execute calls for local transport") + } + + // Local executor should have run the command + if !strings.Contains(res.Stdout, "local-test") { + t.Fatalf("expected stdout to contain 'local-test', got %q", res.Stdout) + } +} + +func TestNewMCPServerRegistersValidateAndStatus(t *testing.T) { + ctx := context.Background() + core := NewCore(basicRegistry(), newFakeRunner(), nil) + s := NewMCPServer(core) + c := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "v0.0.1"}, nil) + t1, t2 := mcp.NewInMemoryTransports() + ss, err := s.Connect(ctx, t1, nil) + if err != nil { + t.Fatalf("server connect: %v", err) + } + defer func() { _ = ss.Close() }() + cs, err := c.Connect(ctx, t2, nil) + if err != nil { + t.Fatalf("client connect: %v", err) + } + defer func() { _ = cs.Close() }() + + found := map[string]*mcp.Tool{} + for tool, err := range cs.Tools(ctx, nil) { + if err != nil { + t.Fatalf("tools iterator error: %v", err) + } + found[tool.Name] = tool + } + + for _, name := range []string{"validate", "status"} { + tool, ok := found[name] + if !ok { + t.Fatalf("missing tool %q", name) + } + if tool.Annotations == nil { + t.Fatalf("tool %q missing annotations", name) + } + if !tool.Annotations.ReadOnlyHint { + t.Fatalf("tool %q should be read-only", name) + } + if !tool.Annotations.IdempotentHint { + t.Fatalf("tool %q should be idempotent", name) + } + } +}