2
2
package ssh
3
3
4
4
import (
5
+ "context"
5
6
"fmt"
6
7
"reflect"
7
8
"strconv"
@@ -11,6 +12,7 @@ import (
11
12
12
13
"github.com/kevinburke/ssh_config"
13
14
"golang.org/x/crypto/ssh"
15
+ "golang.org/x/net/proxy"
14
16
)
15
17
16
18
// DefaultClient is the default SSH client.
@@ -115,7 +117,7 @@ func (c *command) connect() error {
115
117
116
118
overrideConfig (c .config , config )
117
119
118
- c .client , err = ssh . Dial ("tcp" , c .getHostWithPort (), config )
120
+ c .client , err = dial ("tcp" , c .getHostWithPort (), config )
119
121
if err != nil {
120
122
return err
121
123
}
@@ -130,6 +132,29 @@ func (c *command) connect() error {
130
132
return nil
131
133
}
132
134
135
+ func dial (network , addr string , config * ssh.ClientConfig ) (* ssh.Client , error ) {
136
+ var (
137
+ ctx = context .Background ()
138
+ cancel context.CancelFunc
139
+ )
140
+ if config .Timeout > 0 {
141
+ ctx , cancel = context .WithTimeout (ctx , config .Timeout )
142
+ } else {
143
+ ctx , cancel = context .WithCancel (ctx )
144
+ }
145
+ defer cancel ()
146
+
147
+ conn , err := proxy .Dial (ctx , network , addr )
148
+ if err != nil {
149
+ return nil , err
150
+ }
151
+ c , chans , reqs , err := ssh .NewClientConn (conn , addr , config )
152
+ if err != nil {
153
+ return nil , err
154
+ }
155
+ return ssh .NewClient (c , chans , reqs ), nil
156
+ }
157
+
133
158
func (c * command ) getHostWithPort () string {
134
159
if addr , found := c .doGetHostWithPortFromSSHConfig (); found {
135
160
return addr
0 commit comments