Skip to content

Commit 9111bb8

Browse files
authored
Dialer: add optional method NetDialTLSContext (#746)
Fixes issue: #745 With the previous interface, NetDial and NetDialContext were used for both TLS and non-TLS TCP connections, and afterwards TLSClientConfig was used to do the TLS handshake. While this API works for most cases, it prevents from using more advance authentication methods during the TLS handshake, as this is out of the control of the user. This commits introduces another a new dial method, NetDialTLSContext, which is used when dialing for TLS/TCP. The code then assumes that the handshake is done there and TLSClientConfig is not used. This API change is fully backwards compatible and it better aligns with net/http.Transport API, which has these two dial flavors. See: https://pkg.go.dev/net/http#Transport Signed-off-by: Lluis Campos <[email protected]>
1 parent 2f25f78 commit 9111bb8

File tree

2 files changed

+215
-8
lines changed

2 files changed

+215
-8
lines changed

Diff for: client.go

+37-8
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,15 @@ type Dialer struct {
5656
NetDial func(network, addr string) (net.Conn, error)
5757

5858
// NetDialContext specifies the dial function for creating TCP connections. If
59-
// NetDialContext is nil, net.DialContext is used.
59+
// NetDialContext is nil, NetDial is used.
6060
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
6161

62+
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
63+
// NetDialTLSContext is nil, NetDialContext is used.
64+
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
65+
// TLSClientConfig is ignored.
66+
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
67+
6268
// Proxy specifies a function to return a proxy for a given
6369
// Request. If the function returns a non-nil error, the
6470
// request is aborted with the provided error.
@@ -67,6 +73,8 @@ type Dialer struct {
6773

6874
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
6975
// If nil, the default configuration is used.
76+
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
77+
// is done there and TLSClientConfig is ignored.
7078
TLSClientConfig *tls.Config
7179

7280
// HandshakeTimeout specifies the duration for the handshake to complete.
@@ -239,13 +247,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
239247
// Get network dial function.
240248
var netDial func(network, add string) (net.Conn, error)
241249

242-
if d.NetDialContext != nil {
243-
netDial = func(network, addr string) (net.Conn, error) {
244-
return d.NetDialContext(ctx, network, addr)
250+
switch u.Scheme {
251+
case "http":
252+
if d.NetDialContext != nil {
253+
netDial = func(network, addr string) (net.Conn, error) {
254+
return d.NetDialContext(ctx, network, addr)
255+
}
256+
} else if d.NetDial != nil {
257+
netDial = d.NetDial
245258
}
246-
} else if d.NetDial != nil {
247-
netDial = d.NetDial
248-
} else {
259+
case "https":
260+
if d.NetDialTLSContext != nil {
261+
netDial = func(network, addr string) (net.Conn, error) {
262+
return d.NetDialTLSContext(ctx, network, addr)
263+
}
264+
} else if d.NetDialContext != nil {
265+
netDial = func(network, addr string) (net.Conn, error) {
266+
return d.NetDialContext(ctx, network, addr)
267+
}
268+
} else if d.NetDial != nil {
269+
netDial = d.NetDial
270+
}
271+
default:
272+
return nil, nil, errMalformedURL
273+
}
274+
275+
if netDial == nil {
249276
netDialer := &net.Dialer{}
250277
netDial = func(network, addr string) (net.Conn, error) {
251278
return netDialer.DialContext(ctx, network, addr)
@@ -306,7 +333,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
306333
}
307334
}()
308335

309-
if u.Scheme == "https" {
336+
if u.Scheme == "https" && d.NetDialTLSContext == nil {
337+
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
338+
310339
cfg := cloneTLSConfig(d.TLSClientConfig)
311340
if cfg.ServerName == "" {
312341
cfg.ServerName = hostNoPort

Diff for: client_server_test.go

+178
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"crypto/x509"
1212
"encoding/base64"
1313
"encoding/binary"
14+
"errors"
1415
"fmt"
1516
"io"
1617
"io/ioutil"
@@ -920,3 +921,180 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
920921
defer ws.Close()
921922
sendRecv(t, ws)
922923
}
924+
925+
// TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
926+
func TestNetDialConnect(t *testing.T) {
927+
928+
upgrader := Upgrader{}
929+
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
930+
if IsWebSocketUpgrade(r) {
931+
c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
932+
if err != nil {
933+
t.Fatal(err)
934+
}
935+
c.Close()
936+
} else {
937+
w.Header().Set("X-Test-Host", r.Host)
938+
}
939+
})
940+
941+
server := httptest.NewServer(handler)
942+
defer server.Close()
943+
944+
tlsServer := httptest.NewTLSServer(handler)
945+
defer tlsServer.Close()
946+
947+
testUrls := map[*httptest.Server]string{
948+
server: "ws://" + server.Listener.Addr().String() + "/",
949+
tlsServer: "wss://" + tlsServer.Listener.Addr().String() + "/",
950+
}
951+
952+
cas := rootCAs(t, tlsServer)
953+
tlsConfig := &tls.Config{
954+
RootCAs: cas,
955+
ServerName: "example.com",
956+
InsecureSkipVerify: false,
957+
}
958+
959+
tests := []struct {
960+
name string
961+
server *httptest.Server // server to use
962+
netDial func(network, addr string) (net.Conn, error)
963+
netDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
964+
netDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
965+
tlsClientConfig *tls.Config
966+
}{
967+
968+
{
969+
name: "HTTP server, all NetDial* defined, shall use NetDialContext",
970+
server: server,
971+
netDial: func(network, addr string) (net.Conn, error) {
972+
return nil, errors.New("NetDial should not be called")
973+
},
974+
netDialContext: func(_ context.Context, network, addr string) (net.Conn, error) {
975+
return net.Dial(network, addr)
976+
},
977+
netDialTLSContext: func(_ context.Context, network, addr string) (net.Conn, error) {
978+
return nil, errors.New("NetDialTLSContext should not be called")
979+
},
980+
tlsClientConfig: nil,
981+
},
982+
{
983+
name: "HTTP server, all NetDial* undefined",
984+
server: server,
985+
netDial: nil,
986+
netDialContext: nil,
987+
netDialTLSContext: nil,
988+
tlsClientConfig: nil,
989+
},
990+
{
991+
name: "HTTP server, NetDialContext undefined, shall fallback to NetDial",
992+
server: server,
993+
netDial: func(network, addr string) (net.Conn, error) {
994+
return net.Dial(network, addr)
995+
},
996+
netDialContext: nil,
997+
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
998+
return nil, errors.New("NetDialTLSContext should not be called")
999+
},
1000+
tlsClientConfig: nil,
1001+
},
1002+
{
1003+
name: "HTTPS server, all NetDial* defined, shall use NetDialTLSContext",
1004+
server: tlsServer,
1005+
netDial: func(network, addr string) (net.Conn, error) {
1006+
return nil, errors.New("NetDial should not be called")
1007+
},
1008+
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1009+
return nil, errors.New("NetDialContext should not be called")
1010+
},
1011+
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1012+
netConn, err := net.Dial(network, addr)
1013+
if err != nil {
1014+
return nil, err
1015+
}
1016+
tlsConn := tls.Client(netConn, tlsConfig)
1017+
err = tlsConn.Handshake()
1018+
if err != nil {
1019+
return nil, err
1020+
}
1021+
return tlsConn, nil
1022+
},
1023+
tlsClientConfig: nil,
1024+
},
1025+
{
1026+
name: "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake",
1027+
server: tlsServer,
1028+
netDial: func(network, addr string) (net.Conn, error) {
1029+
return nil, errors.New("NetDial should not be called")
1030+
},
1031+
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1032+
return net.Dial(network, addr)
1033+
},
1034+
netDialTLSContext: nil,
1035+
tlsClientConfig: tlsConfig,
1036+
},
1037+
{
1038+
name: "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake",
1039+
server: tlsServer,
1040+
netDial: func(network, addr string) (net.Conn, error) {
1041+
return net.Dial(network, addr)
1042+
},
1043+
netDialContext: nil,
1044+
netDialTLSContext: nil,
1045+
tlsClientConfig: tlsConfig,
1046+
},
1047+
{
1048+
name: "HTTPS server, all NetDial* undefined",
1049+
server: tlsServer,
1050+
netDial: nil,
1051+
netDialContext: nil,
1052+
netDialTLSContext: nil,
1053+
tlsClientConfig: tlsConfig,
1054+
},
1055+
{
1056+
name: "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake",
1057+
server: tlsServer,
1058+
netDial: func(network, addr string) (net.Conn, error) {
1059+
return nil, errors.New("NetDial should not be called")
1060+
},
1061+
netDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1062+
return nil, errors.New("NetDialContext should not be called")
1063+
},
1064+
netDialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
1065+
netConn, err := net.Dial(network, addr)
1066+
if err != nil {
1067+
return nil, err
1068+
}
1069+
tlsConn := tls.Client(netConn, tlsConfig)
1070+
err = tlsConn.Handshake()
1071+
if err != nil {
1072+
return nil, err
1073+
}
1074+
return tlsConn, nil
1075+
},
1076+
tlsClientConfig: &tls.Config{
1077+
RootCAs: nil,
1078+
ServerName: "badserver.com",
1079+
InsecureSkipVerify: false,
1080+
},
1081+
},
1082+
}
1083+
1084+
for _, tc := range tests {
1085+
dialer := Dialer{
1086+
NetDial: tc.netDial,
1087+
NetDialContext: tc.netDialContext,
1088+
NetDialTLSContext: tc.netDialTLSContext,
1089+
TLSClientConfig: tc.tlsClientConfig,
1090+
}
1091+
1092+
// Test websocket dial
1093+
c, _, err := dialer.Dial(testUrls[tc.server], nil)
1094+
if err != nil {
1095+
t.Errorf("FAILED %s, err: %s", tc.name, err.Error())
1096+
} else {
1097+
c.Close()
1098+
}
1099+
}
1100+
}

0 commit comments

Comments
 (0)