@@ -15,11 +15,11 @@ import (
15
15
// NewDialer returns a new Dialer that dials through the provided
16
16
// proxy server's network and address.
17
17
func NewDialer (addr string ) (* Dialer , error ) {
18
- host , config , err := clientConfig (addr )
18
+ config , err := parseClientConfig (addr )
19
19
if err != nil {
20
20
return nil , err
21
21
}
22
- return NewDialerWithConfig (host , config )
22
+ return NewDialerWithConfig (config . host , config . clientConfig )
23
23
}
24
24
25
25
func NewDialerWithConfig (host string , config * ssh.ClientConfig ) (* Dialer , error ) {
@@ -29,10 +29,15 @@ func NewDialerWithConfig(host string, config *ssh.ClientConfig) (*Dialer, error)
29
29
}, nil
30
30
}
31
31
32
- func clientConfig (addr string ) (host string , config * ssh.ClientConfig , err error ) {
32
+ type clientConfig struct {
33
+ host string
34
+ clientConfig * ssh.ClientConfig
35
+ }
36
+
37
+ func parseClientConfig (addr string ) (* clientConfig , error ) {
33
38
ur , err := url .Parse (addr )
34
39
if err != nil {
35
- return "" , nil , err
40
+ return nil , err
36
41
}
37
42
38
43
user := ""
@@ -43,7 +48,7 @@ func clientConfig(addr string) (host string, config *ssh.ClientConfig, err error
43
48
pwd , isPwd = ur .User .Password ()
44
49
}
45
50
46
- config = & ssh.ClientConfig {
51
+ config : = & ssh.ClientConfig {
47
52
User : user ,
48
53
HostKeyCallback : ssh .InsecureIgnoreHostKey (),
49
54
}
@@ -54,23 +59,26 @@ func clientConfig(addr string) (host string, config *ssh.ClientConfig, err error
54
59
55
60
identityDatas , err := getQuery (ur .Query ()["identity_data" ], ur .Query ()["identity_file" ])
56
61
if err != nil {
57
- return "" , nil , err
62
+ return nil , err
58
63
}
59
64
for _ , data := range identityDatas {
60
65
signer , err := ssh .ParsePrivateKey (data )
61
66
if err != nil {
62
- return "" , nil , err
67
+ return nil , err
63
68
}
64
69
config .Auth = append (config .Auth , ssh .PublicKeys (signer ))
65
70
}
66
71
67
- host = ur .Hostname ()
72
+ host : = ur .Hostname ()
68
73
port := ur .Port ()
69
74
if port == "" {
70
75
port = "22"
71
76
}
72
- host = net .JoinHostPort (host , port )
73
- return host , config , nil
77
+
78
+ return & clientConfig {
79
+ clientConfig : config ,
80
+ host : net .JoinHostPort (host , port ),
81
+ }, nil
74
82
}
75
83
76
84
type Dialer struct {
@@ -79,20 +87,22 @@ type Dialer struct {
79
87
// ProxyDial specifies the optional dial function for
80
88
// establishing the transport connection.
81
89
ProxyDial func (context.Context , string , string ) (net.Conn , error )
82
- sshCli * ssh.Client
83
- host string
84
- config * ssh.ClientConfig
90
+
91
+ host string
92
+ config * ssh.ClientConfig
93
+
94
+ pool sync.Pool
85
95
}
86
96
87
97
func (d * Dialer ) Close () error {
88
- d .mut .Lock ()
89
- defer d .mut .Unlock ()
90
- if d .sshCli == nil {
91
- return nil
98
+ for {
99
+ a := d .pool .Get ()
100
+ if a == nil {
101
+ break
102
+ }
103
+ a .(* ssh.Client ).Close ()
92
104
}
93
- err := d .sshCli .Close ()
94
- d .sshCli = nil
95
- return err
105
+ return nil
96
106
}
97
107
98
108
func (d * Dialer ) proxyDial (ctx context.Context , network , address string ) (net.Conn , error ) {
@@ -105,12 +115,36 @@ func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Co
105
115
}
106
116
107
117
func (d * Dialer ) SSHClient (ctx context.Context ) (* ssh.Client , error ) {
118
+ return d .GetClient (ctx )
119
+ }
120
+
121
+ func (d * Dialer ) GetClient (ctx context.Context ) (* ssh.Client , error ) {
122
+ a := d .pool .Get ()
123
+ if a != nil {
124
+ return a .(* ssh.Client ), nil
125
+ }
126
+
108
127
d .mut .Lock ()
109
128
defer d .mut .Unlock ()
110
- cli := d .sshCli
111
- if cli != nil {
112
- return cli , nil
129
+
130
+ a = d .pool .Get ()
131
+ if a != nil {
132
+ return a .(* ssh.Client ), nil
133
+ }
134
+
135
+ cli , err := d .sshClient (ctx )
136
+ if err != nil {
137
+ return nil , err
113
138
}
139
+
140
+ return cli , nil
141
+ }
142
+
143
+ func (d * Dialer ) PutClient (cli * ssh.Client ) {
144
+ d .pool .Put (cli )
145
+ }
146
+
147
+ func (d * Dialer ) sshClient (ctx context.Context ) (* ssh.Client , error ) {
114
148
conn , err := d .proxyDial (ctx , "tcp" , d .host )
115
149
if err != nil {
116
150
return nil , err
@@ -120,39 +154,42 @@ func (d *Dialer) SSHClient(ctx context.Context) (*ssh.Client, error) {
120
154
if err != nil {
121
155
return nil , err
122
156
}
123
- cli = ssh .NewClient (con , chans , reqs )
124
- d .sshCli = cli
125
- return cli , nil
157
+ return ssh .NewClient (con , chans , reqs ), nil
126
158
}
127
159
128
- func ( d * Dialer ) CommandDialContext ( ctx context. Context , name string , args ... string ) (net. Conn , error ) {
129
- cmd := make ([]string , 0 , len (args )+ 1 )
130
- cmd = append (cmd , name )
160
+ func buildCmd ( name string , args ... string ) string {
161
+ cmds := make ([]string , 0 , len (args )+ 1 )
162
+ cmds = append (cmds , name )
131
163
for _ , arg := range args {
132
- cmd = append (cmd , strconv .Quote (arg ))
164
+ cmds = append (cmds , strconv .Quote (arg ))
133
165
}
134
- return d . commandDialContext ( ctx , strings .Join (cmd , " " ), 1 )
166
+ return strings .Join (cmds , " " )
135
167
}
136
168
137
- func (d * Dialer ) commandDialContext (ctx context.Context , cmd string , retry int ) (net.Conn , error ) {
138
- cli , err := d .SSHClient (ctx )
169
+ func (d * Dialer ) CommandDialContext (ctx context.Context , name string , args ... string ) (net.Conn , error ) {
170
+ cli , err := d .GetClient (ctx )
139
171
if err != nil {
140
172
return nil , err
141
173
}
142
174
sess , err := cli .NewSession ()
143
175
if err != nil {
176
+ if isSSHError (err ) {
177
+ d .PutClient (cli )
178
+ } else {
179
+ cli .Close ()
180
+ }
144
181
return nil , err
145
182
}
183
+ defer d .PutClient (cli )
184
+
146
185
conn1 , conn2 := net .Pipe ()
147
186
sess .Stdin = conn1
148
187
sess .Stdout = conn1
149
188
sess .Stderr = os .Stderr
189
+
190
+ cmd := buildCmd (name , args ... )
150
191
err = sess .Start (cmd )
151
192
if err != nil {
152
- if retry != 0 {
153
- d .Close ()
154
- return d .commandDialContext (ctx , cmd , retry - 1 )
155
- }
156
193
return nil , err
157
194
}
158
195
ctx , cancel := context .WithCancel (ctx )
@@ -180,25 +217,42 @@ func (d *Dialer) commandDialContext(ctx context.Context, cmd string, retry int)
180
217
}
181
218
182
219
func (d * Dialer ) DialContext (ctx context.Context , network , address string ) (net.Conn , error ) {
183
- return d .dialContext (ctx , network , address , 1 )
184
- }
185
-
186
- func (d * Dialer ) dialContext (ctx context.Context , network , address string , retry int ) (net.Conn , error ) {
187
- cli , err := d .SSHClient (ctx )
220
+ cli , err := d .GetClient (ctx )
188
221
if err != nil {
189
222
return nil , err
190
223
}
224
+
191
225
conn , err := cli .DialContext (ctx , network , address )
192
226
if err != nil {
193
- if ctx .Err () != nil {
194
- return nil , err
227
+ if isSSHError (err ) {
228
+ d .PutClient (cli )
229
+ } else {
230
+ cli .Close ()
195
231
}
196
- if retry != 0 {
197
- d .Close ()
198
- return d .dialContext (ctx , network , address , retry - 1 )
232
+ return nil , err
233
+ }
234
+
235
+ d .PutClient (cli )
236
+ return conn , nil
237
+ }
238
+
239
+ func (d * Dialer ) Dial (network , address string ) (net.Conn , error ) {
240
+ cli , err := d .GetClient (context .Background ())
241
+ if err != nil {
242
+ return nil , err
243
+ }
244
+
245
+ conn , err := cli .Dial (network , address )
246
+ if err != nil {
247
+ if isSSHError (err ) {
248
+ d .PutClient (cli )
249
+ } else {
250
+ cli .Close ()
199
251
}
200
252
return nil , err
201
253
}
254
+
255
+ d .PutClient (cli )
202
256
return conn , nil
203
257
}
204
258
@@ -209,21 +263,44 @@ func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Liste
209
263
address = net .JoinHostPort ("0.0.0.0" , port )
210
264
}
211
265
}
212
- return d .listen (ctx , network , address , 1 )
213
- }
214
266
215
- func (d * Dialer ) listen (ctx context.Context , network , address string , retry int ) (net.Listener , error ) {
216
- cli , err := d .SSHClient (ctx )
267
+ cli , err := d .GetClient (ctx )
217
268
if err != nil {
218
269
return nil , err
219
270
}
271
+
220
272
listener , err := cli .Listen (network , address )
221
273
if err != nil {
222
- if retry != 0 {
223
- d .Close ()
224
- return d .listen (ctx , network , address , retry - 1 )
274
+ if isSSHError (err ) {
275
+ d .PutClient (cli )
276
+ } else {
277
+ cli .Close ()
225
278
}
226
279
return nil , err
227
280
}
281
+
282
+ listener = & listenerWithCleanup {
283
+ Listener : listener ,
284
+ cleanup : func () {
285
+ d .PutClient (cli )
286
+ },
287
+ }
288
+
228
289
return listener , nil
229
290
}
291
+
292
+ type listenerWithCleanup struct {
293
+ net.Listener
294
+ cleanup func ()
295
+ }
296
+
297
+ func (l * listenerWithCleanup ) Close () error {
298
+ err := l .Listener .Close ()
299
+ l .cleanup ()
300
+ return err
301
+ }
302
+
303
+ func isSSHError (err error ) bool {
304
+ msg := err .Error ()
305
+ return strings .Contains (msg , "ssh: " )
306
+ }
0 commit comments