Skip to content

Commit 79458df

Browse files
committed
Use connect pool
1 parent c1db717 commit 79458df

File tree

1 file changed

+131
-54
lines changed

1 file changed

+131
-54
lines changed

client.go

Lines changed: 131 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@ import (
1515
// NewDialer returns a new Dialer that dials through the provided
1616
// proxy server's network and address.
1717
func NewDialer(addr string) (*Dialer, error) {
18-
host, config, err := clientConfig(addr)
18+
config, err := parseClientConfig(addr)
1919
if err != nil {
2020
return nil, err
2121
}
22-
return NewDialerWithConfig(host, config)
22+
return NewDialerWithConfig(config.host, config.clientConfig)
2323
}
2424

2525
func NewDialerWithConfig(host string, config *ssh.ClientConfig) (*Dialer, error) {
@@ -29,10 +29,15 @@ func NewDialerWithConfig(host string, config *ssh.ClientConfig) (*Dialer, error)
2929
}, nil
3030
}
3131

32-
func clientConfig(addr string) (host string, config *ssh.ClientConfig, err error) {
32+
type clientConfig struct {
33+
host string
34+
clientConfig *ssh.ClientConfig
35+
}
36+
37+
func parseClientConfig(addr string) (*clientConfig, error) {
3338
ur, err := url.Parse(addr)
3439
if err != nil {
35-
return "", nil, err
40+
return nil, err
3641
}
3742

3843
user := ""
@@ -43,7 +48,7 @@ func clientConfig(addr string) (host string, config *ssh.ClientConfig, err error
4348
pwd, isPwd = ur.User.Password()
4449
}
4550

46-
config = &ssh.ClientConfig{
51+
config := &ssh.ClientConfig{
4752
User: user,
4853
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
4954
}
@@ -54,23 +59,26 @@ func clientConfig(addr string) (host string, config *ssh.ClientConfig, err error
5459

5560
identityDatas, err := getQuery(ur.Query()["identity_data"], ur.Query()["identity_file"])
5661
if err != nil {
57-
return "", nil, err
62+
return nil, err
5863
}
5964
for _, data := range identityDatas {
6065
signer, err := ssh.ParsePrivateKey(data)
6166
if err != nil {
62-
return "", nil, err
67+
return nil, err
6368
}
6469
config.Auth = append(config.Auth, ssh.PublicKeys(signer))
6570
}
6671

67-
host = ur.Hostname()
72+
host := ur.Hostname()
6873
port := ur.Port()
6974
if port == "" {
7075
port = "22"
7176
}
72-
host = net.JoinHostPort(host, port)
73-
return host, config, nil
77+
78+
return &clientConfig{
79+
clientConfig: config,
80+
host: net.JoinHostPort(host, port),
81+
}, nil
7482
}
7583

7684
type Dialer struct {
@@ -79,20 +87,22 @@ type Dialer struct {
7987
// ProxyDial specifies the optional dial function for
8088
// establishing the transport connection.
8189
ProxyDial func(context.Context, string, string) (net.Conn, error)
82-
sshCli *ssh.Client
83-
host string
84-
config *ssh.ClientConfig
90+
91+
host string
92+
config *ssh.ClientConfig
93+
94+
pool sync.Pool
8595
}
8696

8797
func (d *Dialer) Close() error {
88-
d.mut.Lock()
89-
defer d.mut.Unlock()
90-
if d.sshCli == nil {
91-
return nil
98+
for {
99+
a := d.pool.Get()
100+
if a == nil {
101+
break
102+
}
103+
a.(*ssh.Client).Close()
92104
}
93-
err := d.sshCli.Close()
94-
d.sshCli = nil
95-
return err
105+
return nil
96106
}
97107

98108
func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Conn, error) {
@@ -105,12 +115,36 @@ func (d *Dialer) proxyDial(ctx context.Context, network, address string) (net.Co
105115
}
106116

107117
func (d *Dialer) SSHClient(ctx context.Context) (*ssh.Client, error) {
118+
return d.GetClient(ctx)
119+
}
120+
121+
func (d *Dialer) GetClient(ctx context.Context) (*ssh.Client, error) {
122+
a := d.pool.Get()
123+
if a != nil {
124+
return a.(*ssh.Client), nil
125+
}
126+
108127
d.mut.Lock()
109128
defer d.mut.Unlock()
110-
cli := d.sshCli
111-
if cli != nil {
112-
return cli, nil
129+
130+
a = d.pool.Get()
131+
if a != nil {
132+
return a.(*ssh.Client), nil
133+
}
134+
135+
cli, err := d.sshClient(ctx)
136+
if err != nil {
137+
return nil, err
113138
}
139+
140+
return cli, nil
141+
}
142+
143+
func (d *Dialer) PutClient(cli *ssh.Client) {
144+
d.pool.Put(cli)
145+
}
146+
147+
func (d *Dialer) sshClient(ctx context.Context) (*ssh.Client, error) {
114148
conn, err := d.proxyDial(ctx, "tcp", d.host)
115149
if err != nil {
116150
return nil, err
@@ -120,39 +154,42 @@ func (d *Dialer) SSHClient(ctx context.Context) (*ssh.Client, error) {
120154
if err != nil {
121155
return nil, err
122156
}
123-
cli = ssh.NewClient(con, chans, reqs)
124-
d.sshCli = cli
125-
return cli, nil
157+
return ssh.NewClient(con, chans, reqs), nil
126158
}
127159

128-
func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...string) (net.Conn, error) {
129-
cmd := make([]string, 0, len(args)+1)
130-
cmd = append(cmd, name)
160+
func buildCmd(name string, args ...string) string {
161+
cmds := make([]string, 0, len(args)+1)
162+
cmds = append(cmds, name)
131163
for _, arg := range args {
132-
cmd = append(cmd, strconv.Quote(arg))
164+
cmds = append(cmds, strconv.Quote(arg))
133165
}
134-
return d.commandDialContext(ctx, strings.Join(cmd, " "), 1)
166+
return strings.Join(cmds, " ")
135167
}
136168

137-
func (d *Dialer) commandDialContext(ctx context.Context, cmd string, retry int) (net.Conn, error) {
138-
cli, err := d.SSHClient(ctx)
169+
func (d *Dialer) CommandDialContext(ctx context.Context, name string, args ...string) (net.Conn, error) {
170+
cli, err := d.GetClient(ctx)
139171
if err != nil {
140172
return nil, err
141173
}
142174
sess, err := cli.NewSession()
143175
if err != nil {
176+
if isSSHError(err) {
177+
d.PutClient(cli)
178+
} else {
179+
cli.Close()
180+
}
144181
return nil, err
145182
}
183+
defer d.PutClient(cli)
184+
146185
conn1, conn2 := net.Pipe()
147186
sess.Stdin = conn1
148187
sess.Stdout = conn1
149188
sess.Stderr = os.Stderr
189+
190+
cmd := buildCmd(name, args...)
150191
err = sess.Start(cmd)
151192
if err != nil {
152-
if retry != 0 {
153-
d.Close()
154-
return d.commandDialContext(ctx, cmd, retry-1)
155-
}
156193
return nil, err
157194
}
158195
ctx, cancel := context.WithCancel(ctx)
@@ -180,25 +217,42 @@ func (d *Dialer) commandDialContext(ctx context.Context, cmd string, retry int)
180217
}
181218

182219
func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
183-
return d.dialContext(ctx, network, address, 1)
184-
}
185-
186-
func (d *Dialer) dialContext(ctx context.Context, network, address string, retry int) (net.Conn, error) {
187-
cli, err := d.SSHClient(ctx)
220+
cli, err := d.GetClient(ctx)
188221
if err != nil {
189222
return nil, err
190223
}
224+
191225
conn, err := cli.DialContext(ctx, network, address)
192226
if err != nil {
193-
if ctx.Err() != nil {
194-
return nil, err
227+
if isSSHError(err) {
228+
d.PutClient(cli)
229+
} else {
230+
cli.Close()
195231
}
196-
if retry != 0 {
197-
d.Close()
198-
return d.dialContext(ctx, network, address, retry-1)
232+
return nil, err
233+
}
234+
235+
d.PutClient(cli)
236+
return conn, nil
237+
}
238+
239+
func (d *Dialer) Dial(network, address string) (net.Conn, error) {
240+
cli, err := d.GetClient(context.Background())
241+
if err != nil {
242+
return nil, err
243+
}
244+
245+
conn, err := cli.Dial(network, address)
246+
if err != nil {
247+
if isSSHError(err) {
248+
d.PutClient(cli)
249+
} else {
250+
cli.Close()
199251
}
200252
return nil, err
201253
}
254+
255+
d.PutClient(cli)
202256
return conn, nil
203257
}
204258

@@ -209,21 +263,44 @@ func (d *Dialer) Listen(ctx context.Context, network, address string) (net.Liste
209263
address = net.JoinHostPort("0.0.0.0", port)
210264
}
211265
}
212-
return d.listen(ctx, network, address, 1)
213-
}
214266

215-
func (d *Dialer) listen(ctx context.Context, network, address string, retry int) (net.Listener, error) {
216-
cli, err := d.SSHClient(ctx)
267+
cli, err := d.GetClient(ctx)
217268
if err != nil {
218269
return nil, err
219270
}
271+
220272
listener, err := cli.Listen(network, address)
221273
if err != nil {
222-
if retry != 0 {
223-
d.Close()
224-
return d.listen(ctx, network, address, retry-1)
274+
if isSSHError(err) {
275+
d.PutClient(cli)
276+
} else {
277+
cli.Close()
225278
}
226279
return nil, err
227280
}
281+
282+
listener = &listenerWithCleanup{
283+
Listener: listener,
284+
cleanup: func() {
285+
d.PutClient(cli)
286+
},
287+
}
288+
228289
return listener, nil
229290
}
291+
292+
type listenerWithCleanup struct {
293+
net.Listener
294+
cleanup func()
295+
}
296+
297+
func (l *listenerWithCleanup) Close() error {
298+
err := l.Listener.Close()
299+
l.cleanup()
300+
return err
301+
}
302+
303+
func isSSHError(err error) bool {
304+
msg := err.Error()
305+
return strings.Contains(msg, "ssh: ")
306+
}

0 commit comments

Comments
 (0)