diff --git a/go.mod b/go.mod index 5189264e50..8032023fa3 100644 --- a/go.mod +++ b/go.mod @@ -78,6 +78,7 @@ require ( github.com/KyleBanks/depth v1.2.1 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/ProtonMail/go-crypto v1.0.0 // indirect + github.com/UserExistsError/conpty v0.1.4 // indirect github.com/akutz/memconn v0.1.0 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect diff --git a/go.sum b/go.sum index b2e9c69727..4426403582 100644 --- a/go.sum +++ b/go.sum @@ -626,6 +626,8 @@ github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/ProtonMail/go-crypto v1.0.0 h1:LRuvITjQWX+WIfr930YHG2HNfjR1uOfyf5vE0kC2U78= github.com/ProtonMail/go-crypto v1.0.0/go.mod h1:EjAoLdwvbIOoOQr3ihjnSoLZRtE8azugULFRteWMNc0= +github.com/UserExistsError/conpty v0.1.4 h1:+3FhJhiqhyEJa+K5qaK3/w6w+sN3Nh9O9VbJyBS02to= +github.com/UserExistsError/conpty v0.1.4/go.mod h1:PDglKIkX3O/2xVk0MV9a6bCWxRmPVfxqZoTG/5sSd9I= github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= diff --git a/pkg/agent/ssh/server.go b/pkg/agent/ssh/server.go index 4585e92d0b..f1ff3ef2a5 100644 --- a/pkg/agent/ssh/server.go +++ b/pkg/agent/ssh/server.go @@ -8,9 +8,8 @@ import ( "io" "os" "os/exec" - "strings" + "runtime" - "github.com/creack/pty" "github.com/daytonaio/daytona/pkg/agent/ssh/config" "github.com/daytonaio/daytona/pkg/common" "github.com/gliderlabs/ssh" @@ -102,12 +101,19 @@ func (s *Server) handlePty(session ssh.Session, ptyReq ssh.Pty, winCh <-chan ssh cmd.Env = append(cmd.Env, fmt.Sprintf("TERM=%s", ptyReq.Term)) cmd.Env = append(cmd.Env, os.Environ()...) cmd.Env = append(cmd.Env, fmt.Sprintf("SHELL=%s", shell)) - f, err := pty.Start(cmd) - if err != nil { - log.Errorf("Unable to start command: %v", err) - return - } + var f io.ReadWriteCloser + if runtime.GOOS == "windows" { + output := Start(shell) + if output != nil { + f = output + } + } else { + output := Start(cmd) + if output != nil { + f = output + } + } go func() { for win := range winCh { SetPtySize(f, win) @@ -125,7 +131,7 @@ func (s *Server) handleNonPty(session ssh.Session) { args = append([]string{"-c"}, session.RawCommand()) } - cmd := exec.Command("/bin/sh", args...) + cmd := exec.Command("sh", args...) cmd.Env = append(cmd.Env, os.Environ()...) @@ -196,33 +202,6 @@ func (s *Server) handleNonPty(session ssh.Session) { } func (s *Server) getShell() string { - out, err := exec.Command("sh", "-c", "grep '^[^#]' /etc/shells").Output() - if err != nil { - return "sh" - } - - if strings.Contains(string(out), "/usr/bin/zsh") { - return "/usr/bin/zsh" - } - - if strings.Contains(string(out), "/bin/zsh") { - return "/bin/zsh" - } - - if strings.Contains(string(out), "/usr/bin/bash") { - return "/usr/bin/bash" - } - - if strings.Contains(string(out), "/bin/bash") { - return "/bin/bash" - } - - shellEnv, shellSet := os.LookupEnv("SHELL") - - if shellSet { - return shellEnv - } - return "sh" } diff --git a/pkg/agent/ssh/server_unix.go b/pkg/agent/ssh/server_unix.go index 6c15478bdb..198cccfd4d 100644 --- a/pkg/agent/ssh/server_unix.go +++ b/pkg/agent/ssh/server_unix.go @@ -7,16 +7,35 @@ package ssh import ( "os" + "os/exec" "syscall" "unsafe" + "github.com/creack/pty" "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" ) -func SetPtySize(f *os.File, win ssh.Window) { - syscall.Syscall(syscall.SYS_IOCTL, f.Fd(), uintptr(syscall.TIOCSWINSZ), - uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) +func Start(cmd interface{}) *os.File { + if command, ok := cmd.(*exec.Cmd); ok { + f, err := pty.Start(command) + if err != nil { + log.Errorf("Unable to start PTY: %v", err) + return nil + } + return f + } + return nil +} + +func SetPtySize(f interface{}, win ssh.Window) { + if file, ok := f.(*os.File); ok { + syscall.Syscall(syscall.SYS_IOCTL, file.Fd(), uintptr(syscall.TIOCSWINSZ), + uintptr(unsafe.Pointer(&struct{ h, w, x, y uint16 }{uint16(win.Height), uint16(win.Width), 0, 0}))) + } else { + log.Errorf("Unable to resize PTY") + } } func OsSignalFrom(sig ssh.Signal) os.Signal { diff --git a/pkg/agent/ssh/server_win.go b/pkg/agent/ssh/server_win.go index 6eb271eb9b..747957878d 100644 --- a/pkg/agent/ssh/server_win.go +++ b/pkg/agent/ssh/server_win.go @@ -8,9 +8,10 @@ package ssh import ( "os" "syscall" - "unsafe" + "github.com/UserExistsError/conpty" "github.com/gliderlabs/ssh" + log "github.com/sirupsen/logrus" ) var ( @@ -18,15 +19,25 @@ var ( setConsoleWindowInfo = kernel32.NewProc("SetConsoleWindowInfo") ) -func SetPtySize(f *os.File, win ssh.Window) { - handle := f.Fd() - var rect struct { - Left, Top, Right, Bottom int16 +func Start(cmd interface{}) *conpty.ConPty { + if shell, ok := cmd.(string); ok { + f, err := conpty.Start(shell) + if err != nil { + log.Errorf("Unable to start ConPTY: %v", err) + return nil + } + return f } - rect.Right = int16(win.Width - 1) - rect.Bottom = int16(win.Height - 1) - setConsoleWindowInfo.Call(uintptr(handle), uintptr(1), uintptr(unsafe.Pointer(&rect))) + return nil +} + +func SetPtySize(f interface{}, win ssh.Window) { + if cpty, ok := f.(*conpty.ConPty); ok { + cpty.Resize(win.Width, win.Height) + } else { + log.Errorf("Unable to resize ConPTY") + } } func OsSignalFrom(sig ssh.Signal) os.Signal {