@@ -11,6 +11,7 @@ import (
11
11
"crypto/x509"
12
12
"encoding/base64"
13
13
"encoding/binary"
14
+ "errors"
14
15
"fmt"
15
16
"io"
16
17
"io/ioutil"
@@ -920,3 +921,180 @@ func TestEmptyTracingDialWithContext(t *testing.T) {
920
921
defer ws .Close ()
921
922
sendRecv (t , ws )
922
923
}
924
+
925
+ // TestNetDialConnect tests selection of dial method between NetDial, NetDialContext, NetDialTLS or NetDialTLSContext
926
+ func TestNetDialConnect (t * testing.T ) {
927
+
928
+ upgrader := Upgrader {}
929
+ handler := http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
930
+ if IsWebSocketUpgrade (r ) {
931
+ c , err := upgrader .Upgrade (w , r , http.Header {"X-Test-Host" : {r .Host }})
932
+ if err != nil {
933
+ t .Fatal (err )
934
+ }
935
+ c .Close ()
936
+ } else {
937
+ w .Header ().Set ("X-Test-Host" , r .Host )
938
+ }
939
+ })
940
+
941
+ server := httptest .NewServer (handler )
942
+ defer server .Close ()
943
+
944
+ tlsServer := httptest .NewTLSServer (handler )
945
+ defer tlsServer .Close ()
946
+
947
+ testUrls := map [* httptest.Server ]string {
948
+ server : "ws://" + server .Listener .Addr ().String () + "/" ,
949
+ tlsServer : "wss://" + tlsServer .Listener .Addr ().String () + "/" ,
950
+ }
951
+
952
+ cas := rootCAs (t , tlsServer )
953
+ tlsConfig := & tls.Config {
954
+ RootCAs : cas ,
955
+ ServerName : "example.com" ,
956
+ InsecureSkipVerify : false ,
957
+ }
958
+
959
+ tests := []struct {
960
+ name string
961
+ server * httptest.Server // server to use
962
+ netDial func (network , addr string ) (net.Conn , error )
963
+ netDialContext func (ctx context.Context , network , addr string ) (net.Conn , error )
964
+ netDialTLSContext func (ctx context.Context , network , addr string ) (net.Conn , error )
965
+ tlsClientConfig * tls.Config
966
+ }{
967
+
968
+ {
969
+ name : "HTTP server, all NetDial* defined, shall use NetDialContext" ,
970
+ server : server ,
971
+ netDial : func (network , addr string ) (net.Conn , error ) {
972
+ return nil , errors .New ("NetDial should not be called" )
973
+ },
974
+ netDialContext : func (_ context.Context , network , addr string ) (net.Conn , error ) {
975
+ return net .Dial (network , addr )
976
+ },
977
+ netDialTLSContext : func (_ context.Context , network , addr string ) (net.Conn , error ) {
978
+ return nil , errors .New ("NetDialTLSContext should not be called" )
979
+ },
980
+ tlsClientConfig : nil ,
981
+ },
982
+ {
983
+ name : "HTTP server, all NetDial* undefined" ,
984
+ server : server ,
985
+ netDial : nil ,
986
+ netDialContext : nil ,
987
+ netDialTLSContext : nil ,
988
+ tlsClientConfig : nil ,
989
+ },
990
+ {
991
+ name : "HTTP server, NetDialContext undefined, shall fallback to NetDial" ,
992
+ server : server ,
993
+ netDial : func (network , addr string ) (net.Conn , error ) {
994
+ return net .Dial (network , addr )
995
+ },
996
+ netDialContext : nil ,
997
+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
998
+ return nil , errors .New ("NetDialTLSContext should not be called" )
999
+ },
1000
+ tlsClientConfig : nil ,
1001
+ },
1002
+ {
1003
+ name : "HTTPS server, all NetDial* defined, shall use NetDialTLSContext" ,
1004
+ server : tlsServer ,
1005
+ netDial : func (network , addr string ) (net.Conn , error ) {
1006
+ return nil , errors .New ("NetDial should not be called" )
1007
+ },
1008
+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1009
+ return nil , errors .New ("NetDialContext should not be called" )
1010
+ },
1011
+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1012
+ netConn , err := net .Dial (network , addr )
1013
+ if err != nil {
1014
+ return nil , err
1015
+ }
1016
+ tlsConn := tls .Client (netConn , tlsConfig )
1017
+ err = tlsConn .Handshake ()
1018
+ if err != nil {
1019
+ return nil , err
1020
+ }
1021
+ return tlsConn , nil
1022
+ },
1023
+ tlsClientConfig : nil ,
1024
+ },
1025
+ {
1026
+ name : "HTTPS server, NetDialTLSContext undefined, shall fallback to NetDialContext and do handshake" ,
1027
+ server : tlsServer ,
1028
+ netDial : func (network , addr string ) (net.Conn , error ) {
1029
+ return nil , errors .New ("NetDial should not be called" )
1030
+ },
1031
+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1032
+ return net .Dial (network , addr )
1033
+ },
1034
+ netDialTLSContext : nil ,
1035
+ tlsClientConfig : tlsConfig ,
1036
+ },
1037
+ {
1038
+ name : "HTTPS server, NetDialTLSContext and NetDialContext undefined, shall fallback to NetDial and do handshake" ,
1039
+ server : tlsServer ,
1040
+ netDial : func (network , addr string ) (net.Conn , error ) {
1041
+ return net .Dial (network , addr )
1042
+ },
1043
+ netDialContext : nil ,
1044
+ netDialTLSContext : nil ,
1045
+ tlsClientConfig : tlsConfig ,
1046
+ },
1047
+ {
1048
+ name : "HTTPS server, all NetDial* undefined" ,
1049
+ server : tlsServer ,
1050
+ netDial : nil ,
1051
+ netDialContext : nil ,
1052
+ netDialTLSContext : nil ,
1053
+ tlsClientConfig : tlsConfig ,
1054
+ },
1055
+ {
1056
+ name : "HTTPS server, all NetDialTLSContext defined, dummy TlsClientConfig defined, shall not do handshake" ,
1057
+ server : tlsServer ,
1058
+ netDial : func (network , addr string ) (net.Conn , error ) {
1059
+ return nil , errors .New ("NetDial should not be called" )
1060
+ },
1061
+ netDialContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1062
+ return nil , errors .New ("NetDialContext should not be called" )
1063
+ },
1064
+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
1065
+ netConn , err := net .Dial (network , addr )
1066
+ if err != nil {
1067
+ return nil , err
1068
+ }
1069
+ tlsConn := tls .Client (netConn , tlsConfig )
1070
+ err = tlsConn .Handshake ()
1071
+ if err != nil {
1072
+ return nil , err
1073
+ }
1074
+ return tlsConn , nil
1075
+ },
1076
+ tlsClientConfig : & tls.Config {
1077
+ RootCAs : nil ,
1078
+ ServerName : "badserver.com" ,
1079
+ InsecureSkipVerify : false ,
1080
+ },
1081
+ },
1082
+ }
1083
+
1084
+ for _ , tc := range tests {
1085
+ dialer := Dialer {
1086
+ NetDial : tc .netDial ,
1087
+ NetDialContext : tc .netDialContext ,
1088
+ NetDialTLSContext : tc .netDialTLSContext ,
1089
+ TLSClientConfig : tc .tlsClientConfig ,
1090
+ }
1091
+
1092
+ // Test websocket dial
1093
+ c , _ , err := dialer .Dial (testUrls [tc .server ], nil )
1094
+ if err != nil {
1095
+ t .Errorf ("FAILED %s, err: %s" , tc .name , err .Error ())
1096
+ } else {
1097
+ c .Close ()
1098
+ }
1099
+ }
1100
+ }
0 commit comments