diff --git a/common/pkg/ssh/connection_golang.go b/common/pkg/ssh/connection_golang.go index c8e3c5bd85..cc409fd388 100644 --- a/common/pkg/ssh/connection_golang.go +++ b/common/pkg/ssh/connection_golang.go @@ -192,6 +192,87 @@ func golangConnectionScp(options ConnectionScpOptions) (*ConnectionScpReport, er return &ConnectionScpReport{Response: remote.Name()}, nil } +type golangExecOutputReader struct { + stdout io.Reader + sess *ssh.Session + client *ssh.Client + stderr *bytes.Buffer + waitCh <-chan error +} + +func (r *golangExecOutputReader) Read(p []byte) (int, error) { + return r.stdout.Read(p) +} + +func (r *golangExecOutputReader) Close() error { + r.sess.Close() + err := <-r.waitCh + // safe to close the connection now that the session is done + // and the exit status has been received. + r.client.Close() + if err != nil { + return fmt.Errorf("%v: %w", r.stderr.String(), err) + } + return nil +} + +func golangConnectionExecWithOutput(options ConnectionExecOptions, input io.Reader) (io.ReadCloser, error) { + if !strings.HasPrefix(options.Host, "ssh://") { + options.Host = "ssh://" + options.Host + } + _, uri, err := Validate(options.User, options.Host, options.Port, options.Identity) + if err != nil { + return nil, err + } + + cfg, err := ValidateAndConfigure(uri, options.Identity, false) + if err != nil { + return nil, err + } + dialAdd, err := ssh.Dial("tcp", uri.Host, cfg) + if err != nil { + return nil, fmt.Errorf("failed to connect: %w", err) + } + + sess, err := dialAdd.NewSession() + if err != nil { + dialAdd.Close() + return nil, err + } + + stderr := &bytes.Buffer{} + sess.Stderr = stderr + if input != nil { + sess.Stdin = input + } + + stdout, err := sess.StdoutPipe() + if err != nil { + sess.Close() + dialAdd.Close() + return nil, err + } + + if err := sess.Start(strings.Join(options.Args, " ")); err != nil { + sess.Close() + dialAdd.Close() + return nil, err + } + + waitCh := make(chan error, 1) + go func() { + waitCh <- sess.Wait() + }() + + return &golangExecOutputReader{ + stdout: stdout, + sess: sess, + client: dialAdd, + stderr: stderr, + waitCh: waitCh, + }, nil +} + // ExecRemoteCommand takes a ssh client connection and a command to run and executes the // command on the specified client. The function returns the Stdout from the client or the Stderr. func ExecRemoteCommand(dial *ssh.Client, run string) ([]byte, error) { diff --git a/common/pkg/ssh/connection_native.go b/common/pkg/ssh/connection_native.go index a274db88ab..6ee30cbb45 100644 --- a/common/pkg/ssh/connection_native.go +++ b/common/pkg/ssh/connection_native.go @@ -101,19 +101,17 @@ func nativeConnectionCreate(options ConnectionCreateOptions) error { }) } -func nativeConnectionExec(options ConnectionExecOptions, input io.Reader) (*ConnectionExecReport, error) { +func nativePrepareSSHCmd(options ConnectionExecOptions, input io.Reader) (*exec.Cmd, *bytes.Buffer, error) { dst, uri, err := Validate(options.User, options.Host, options.Port, options.Identity) if err != nil { - return nil, err + return nil, nil, err } ssh, err := exec.LookPath("ssh") if err != nil { - return nil, err + return nil, nil, err } - output := &bytes.Buffer{} - errors := &bytes.Buffer{} if host, _, ok := strings.Cut(uri.Host, "/run"); ok { uri.Host = host } @@ -121,7 +119,7 @@ func nativeConnectionExec(options ConnectionExecOptions, input io.Reader) (*Conn options.Args = append([]string{uri.User.String() + "@" + uri.Hostname()}, options.Args...) conf, err := config.Default() if err != nil { - return nil, err + return nil, nil, err } args := []string{} @@ -132,19 +130,67 @@ func nativeConnectionExec(options ConnectionExecOptions, input io.Reader) (*Conn args = append(args, "-F", conf.Engine.SSHConfig) } args = append(args, options.Args...) - info := exec.Command(ssh, args...) - info.Stdout = output - info.Stderr = errors + + stderr := &bytes.Buffer{} + cmd := exec.Command(ssh, args...) + cmd.Stderr = stderr if input != nil { - info.Stdin = input + cmd.Stdin = input } - err = info.Run() + + return cmd, stderr, nil +} + +func nativeConnectionExec(options ConnectionExecOptions, input io.Reader) (*ConnectionExecReport, error) { + cmd, stderr, err := nativePrepareSSHCmd(options, input) if err != nil { return nil, err } + + output := &bytes.Buffer{} + cmd.Stdout = output + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("%v: %w", stderr.String(), err) + } return &ConnectionExecReport{Response: output.String()}, nil } +type nativeExecOutputReader struct { + stdout io.ReadCloser + cmd *exec.Cmd + stderr *bytes.Buffer +} + +func (r *nativeExecOutputReader) Read(p []byte) (int, error) { + return r.stdout.Read(p) +} + +func (r *nativeExecOutputReader) Close() error { + err := r.cmd.Wait() + if err != nil { + return fmt.Errorf("%v: %w", r.stderr.String(), err) + } + return nil +} + +func nativeConnectionExecWithOutput(options ConnectionExecOptions, input io.Reader) (io.ReadCloser, error) { + cmd, stderr, err := nativePrepareSSHCmd(options, input) + if err != nil { + return nil, err + } + + stdout, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + + if err := cmd.Start(); err != nil { + return nil, err + } + + return &nativeExecOutputReader{stdout: stdout, cmd: cmd, stderr: stderr}, nil +} + func nativeConnectionScp(options ConnectionScpOptions) (*ConnectionScpReport, error) { host, remotePath, localPath, swap, err := ParseScpArgs(options) if err != nil { diff --git a/common/pkg/ssh/ssh.go b/common/pkg/ssh/ssh.go index 6e8a923ee4..8bd07ae642 100644 --- a/common/pkg/ssh/ssh.go +++ b/common/pkg/ssh/ssh.go @@ -48,6 +48,13 @@ func ExecWithInput(options *ConnectionExecOptions, kind EngineMode, input io.Rea return rep.Response, nil } +func ExecWithOutput(options *ConnectionExecOptions, kind EngineMode, input io.Reader) (io.ReadCloser, error) { + if kind == NativeMode { + return nativeConnectionExecWithOutput(*options, input) + } + return golangConnectionExecWithOutput(*options, input) +} + func Scp(options *ConnectionScpOptions, kind EngineMode) (string, error) { var rep *ConnectionScpReport var err error diff --git a/common/pkg/ssh/ssh_test.go b/common/pkg/ssh/ssh_test.go index 969ad971ad..de32d90681 100644 --- a/common/pkg/ssh/ssh_test.go +++ b/common/pkg/ssh/ssh_test.go @@ -68,6 +68,27 @@ func TestExecWithInput(t *testing.T) { require.Error(t, err, "failed to connect: ssh: handshake failed: ssh: disconnect, reason 2: Too many authentication failures") } +func TestExecWithOutput(t *testing.T) { + input, err := os.Open("/etc/fstab") + require.NoError(t, err) + defer input.Close() + + options := ConnectionExecOptions{ + Port: 22, + Host: "localhost", + Args: []string{"md5sum"}, + } + + // Native mode: Start() succeeds but the connection fails when the process runs. + // The error surfaces when the caller closes the reader. + rc, err := ExecWithOutput(&options, NativeMode, input) + require.NoError(t, err) + require.Error(t, rc.Close(), "exit status 255") + + _, err = ExecWithOutput(&options, GolangMode, input) + require.Error(t, err, "failed to connect: ssh: handshake failed: ssh: disconnect, reason 2: Too many authentication failures") +} + func TestDial(t *testing.T) { options := ConnectionDialOptions{ Port: 22,