Skip to content

Commit 0b0f26a

Browse files
committed
Implements HTTPS proxy functionality
1 parent 5e00238 commit 0b0f26a

File tree

3 files changed

+1137
-45
lines changed

3 files changed

+1137
-45
lines changed

Diff for: client.go

+122-43
Original file line numberDiff line numberDiff line change
@@ -51,18 +51,34 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
5151
//
5252
// It is safe to call Dialer's methods concurrently.
5353
type Dialer struct {
54+
// The following custom dial functions can be set to establish
55+
// connections to either the backend server or the proxy (if it
56+
// exists). The scheme of the dialed entity (either backend or
57+
// proxy) determines which custom dial function is selected:
58+
// either NetDialTLSContext for HTTPS or NetDialContext/NetDial
59+
// for HTTP. Since the "Proxy" function can determine the scheme
60+
// dynamically, it can make sense to set multiple custom dial
61+
// functions simultaneously.
62+
//
5463
// NetDial specifies the dial function for creating TCP connections. If
5564
// NetDial is nil, net.Dialer DialContext is used.
65+
// If "Proxy" field is also set, this function dials the proxy--not
66+
// the backend server.
5667
NetDial func(network, addr string) (net.Conn, error)
5768

5869
// NetDialContext specifies the dial function for creating TCP connections. If
5970
// NetDialContext is nil, NetDial is used.
71+
// If "Proxy" field is also set, this function dials the proxy--not
72+
// the backend server.
6073
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
6174

6275
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
6376
// NetDialTLSContext is nil, NetDialContext is used.
6477
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
6578
// TLSClientConfig is ignored.
79+
// If "Proxy" field is also set, this function dials the proxy (and performs
80+
// the TLS handshake with the proxy, ignoring TLSClientConfig). In this TLS proxy
81+
// dialing case the TLSClientConfig could still be necessary for TLS to the backend server.
6682
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
6783

6884
// Proxy specifies a function to return a proxy for a given
@@ -73,7 +89,7 @@ type Dialer struct {
7389

7490
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
7591
// If nil, the default configuration is used.
76-
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
92+
// If NetDialTLSContext is set, Dial assumes the TLS handshake
7793
// is done there and TLSClientConfig is ignored.
7894
TLSClientConfig *tls.Config
7995

@@ -244,49 +260,16 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
244260
defer cancel()
245261
}
246262

247-
var netDial netDialerFunc
248-
switch {
249-
case u.Scheme == "https" && d.NetDialTLSContext != nil:
250-
netDial = d.NetDialTLSContext
251-
case d.NetDialContext != nil:
252-
netDial = d.NetDialContext
253-
case d.NetDial != nil:
254-
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
255-
return d.NetDial(net, addr)
256-
}
257-
default:
258-
netDial = (&net.Dialer{}).DialContext
259-
}
260-
261-
// If needed, wrap the dial function to set the connection deadline.
262-
if deadline, ok := ctx.Deadline(); ok {
263-
forwardDial := netDial
264-
netDial = func(ctx context.Context, network, addr string) (net.Conn, error) {
265-
c, err := forwardDial(ctx, network, addr)
266-
if err != nil {
267-
return nil, err
268-
}
269-
err = c.SetDeadline(deadline)
270-
if err != nil {
271-
c.Close()
272-
return nil, err
273-
}
274-
return c, nil
275-
}
276-
}
277-
278-
// If needed, wrap the dial function to connect through a proxy.
263+
var proxyURL *url.URL
279264
if d.Proxy != nil {
280-
proxyURL, err := d.Proxy(req)
265+
proxyURL, err = d.Proxy(req)
281266
if err != nil {
282267
return nil, nil, err
283268
}
284-
if proxyURL != nil {
285-
netDial, err = proxyFromURL(proxyURL, netDial)
286-
if err != nil {
287-
return nil, nil, err
288-
}
289-
}
269+
}
270+
netDial, err := d.netDialFn(ctx, u, proxyURL)
271+
if err != nil {
272+
return nil, nil, err
290273
}
291274

292275
hostPort, hostNoPort := hostPortNoPort(u)
@@ -317,9 +300,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
317300
}
318301
}()
319302

320-
if u.Scheme == "https" && d.NetDialTLSContext == nil {
321-
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
322-
303+
// Do TLS handshake over "netConn" connection if necessary.
304+
if d.needsTLSHandshake(u, proxyURL) {
323305
cfg := cloneTLSConfig(d.TLSClientConfig)
324306
if cfg.ServerName == "" {
325307
cfg.ServerName = hostNoPort
@@ -415,6 +397,103 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
415397
return conn, resp, nil
416398
}
417399

400+
// Returns the dial function to establish the connection to either the backend
401+
// server or the proxy (if it exists). Instead returns an error if one occurred.
402+
func (d *Dialer) netDialFn(ctx context.Context, backendURL *url.URL, proxyURL *url.URL) (netDialerFunc, error) {
403+
netDial := d.netDialFromScheme(backendURL.Scheme)
404+
if proxyURL != nil {
405+
netDial = d.netDialFromScheme(proxyURL.Scheme)
406+
// Wrap proxy dial function to perform TLS handshake if necessary.
407+
if proxyURL.Scheme == "https" && d.NetDialTLSContext == nil {
408+
netDial = netDialWithTLSHandshake(netDial, d.TLSClientConfig, proxyURL)
409+
}
410+
}
411+
// If needed, wrap the dial function to set the connection deadline.
412+
if deadline, ok := ctx.Deadline(); ok {
413+
netDial = netDialWithDeadline(netDial, deadline)
414+
}
415+
// Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth.
416+
if proxyURL != nil {
417+
return proxyFromURL(proxyURL, netDial)
418+
}
419+
return netDial, nil
420+
}
421+
422+
// Returns function to create the connection depending on the Dialer's
423+
// custom dialing functions and the passed scheme of entity connecting to.
424+
func (d *Dialer) netDialFromScheme(scheme string) netDialerFunc {
425+
var netDial netDialerFunc
426+
switch {
427+
case scheme == "https" && d.NetDialTLSContext != nil:
428+
netDial = d.NetDialTLSContext
429+
case d.NetDialContext != nil:
430+
netDial = d.NetDialContext
431+
case d.NetDial != nil:
432+
netDial = func(ctx context.Context, net, addr string) (net.Conn, error) {
433+
return d.NetDial(net, addr)
434+
}
435+
default:
436+
netDial = (&net.Dialer{}).DialContext
437+
}
438+
return netDial
439+
}
440+
441+
// Returns true if a TLS handshake needs to be performed to the backend server
442+
// after the connection has been established (since some dialing functions *also*
443+
// perform TLS handshake).
444+
func (d *Dialer) needsTLSHandshake(backendURL *url.URL, proxyURL *url.URL) bool {
445+
if backendURL.Scheme != "https" {
446+
return false
447+
}
448+
// If a proxy exists, we will always need to do a TLS handshake.
449+
if proxyURL != nil {
450+
return true
451+
}
452+
// Otherwise, we will need to do a TLS handshake to the backend only
453+
// if it has not already been performed by NetDialTLSContext.
454+
return d.NetDialTLSContext == nil
455+
}
456+
457+
// Returns wrapped "netDial" function, performing TLS handshake after connecting.
458+
func netDialWithTLSHandshake(netDial netDialerFunc, tlsConfig *tls.Config, u *url.URL) netDialerFunc {
459+
return func(ctx context.Context, unused, addr string) (net.Conn, error) {
460+
// Creates TCP connection to addr using passed "netDial" function.
461+
conn, err := netDial(ctx, "tcp", addr)
462+
if err != nil {
463+
return nil, err
464+
}
465+
cfg := cloneTLSConfig(tlsConfig)
466+
if cfg.ServerName == "" {
467+
_, hostNoPort := hostPortNoPort(u)
468+
cfg.ServerName = hostNoPort
469+
}
470+
tlsConn := tls.Client(conn, cfg)
471+
// Do the TLS handshake using TLSConfig over the wrapped connection.
472+
err = doHandshake(ctx, tlsConn, cfg)
473+
if err != nil {
474+
tlsConn.Close()
475+
return nil, err
476+
}
477+
return tlsConn, nil
478+
}
479+
}
480+
481+
// Returns wrapped "netDial" function, setting passed deadline.
482+
func netDialWithDeadline(netDial netDialerFunc, deadline time.Time) netDialerFunc {
483+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
484+
c, err := netDial(ctx, network, addr)
485+
if err != nil {
486+
return nil, err
487+
}
488+
err = c.SetDeadline(deadline)
489+
if err != nil {
490+
c.Close()
491+
return nil, err
492+
}
493+
return c, nil
494+
}
495+
}
496+
418497
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
419498
if cfg == nil {
420499
return &tls.Config{}

0 commit comments

Comments
 (0)